diff --git a/CMakeLists.txt b/CMakeLists.txt index aa2d430cca16f412e7d9e6dcfce6638b4f8bb61c..5486b6f98df2ad32a75bece4eb49885c671fc9a7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -134,7 +134,4 @@ add_subdirectory(paddle_inference) endif() add_subdirectory(python) -set(PYTHON_INCLUDE_DIR ${PYTHON_INCLUDE}) -set(PYTHON_LIBRARIES ${PYTHON_LIB}) - #add_subdirectory(examples) diff --git a/core/cube/cube-agent/src/CMakeLists.txt b/core/cube/cube-agent/src/CMakeLists.txt deleted file mode 100644 index eb192f0fd14969e9f25a71a0ba968ea244bca830..0000000000000000000000000000000000000000 --- a/core/cube/cube-agent/src/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License - -set(SOURCE_FILE cube-agent.go) -add_go_executable(cube-agent ${SOURCE_FILE}) -add_dependencies(cube-agent agent-docopt-go) -add_dependencies(cube-agent agent-logex) -add_dependencies(cube-agent agent-pipeline) diff --git a/core/cube/cube-agent/src/agent/define.go b/core/cube/cube-agent/src/agent/define.go deleted file mode 100644 index 1b602b6fc0f2de86325f49ebdeff5b3321bb697a..0000000000000000000000000000000000000000 --- a/core/cube/cube-agent/src/agent/define.go +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package agent - -import ( - "errors" - _ "github.com/Badangel/logex" - "strings" - "sync" -) - -var ( - Dir string - WorkerNum int - QueueCapacity int32 - MasterHost []string - MasterPort []string - - TestHostname string - TestIdc string - ShardLock sync.RWMutex - - CmdWorkPool *WorkPool - CmdWorkFilter sync.Map -) - -type ( - Status struct { - Status string `json:"status"` - Version string `json:"version"` - } - - MasterResp struct { - Success string `json:"success"` - Message string `json:"message"` - Data string `json:"data"` - } - - ShardInfo struct { - DictName string - ShardSeq int - SlotIdList string - DataDir string - Service string `json:"service,omitempty"` - Libcube string `json:"libcube,omitempty"` - } - - CubeResp struct { - Status int `json:"status"` - CurVersion string `json:"cur_version"` - BgVersion string `json:"bg_version"` - } -) - -var BUILTIN_STATUS = Status{"RUNNING", "3.0.0.1"} - -var ShardInfoMap map[string]map[string]*ShardInfo -var disks []string - -func GetMaster(master string) (host, port string, err error) { - if len(ShardInfoMap) < 1 { - return "", "", errors.New("empty master list.") - } - if master == "" { - return MasterHost[0], MasterPort[0], nil - } - if _, ok := ShardInfoMap[master]; ok { - m := strings.Split(master, ":") - if len(m) != 2 { - return MasterHost[0], MasterPort[0], nil - } - return m[0], m[1], nil - } else { - return MasterHost[0], MasterPort[0], nil - } -} diff --git a/core/cube/cube-agent/src/agent/http.go b/core/cube/cube-agent/src/agent/http.go deleted file mode 100755 index 
d548c5c4d424d3ebc2a3fbb141d79f6656e3c58b..0000000000000000000000000000000000000000 --- a/core/cube/cube-agent/src/agent/http.go +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package agent - -import ( - "bytes" - "errors" - "fmt" - "io/ioutil" - "net/http" - "net/url" - "strconv" - "strings" - "time" - - "github.com/Badangel/logex" -) - -type handlerFunc func(subpath string, m map[string]string, b []byte) (string, string, error) - -var ( // key = subpath; eg: path="/checker/job", key="job" - getHandler map[string]handlerFunc - putHandler map[string]handlerFunc - deleteHandler map[string]handlerFunc - postHandler map[string]handlerFunc -) - -func StartHttp(addr string) error { - - // init handlers: - initGetHandlers() - initPostHandlers() - - http.HandleFunc("/agent/", handleRest) - logex.Notice("start http ", addr) - return http.ListenAndServe(addr, nil) -} - -func handleRest(w http.ResponseWriter, r *http.Request) { - var ( - req_log string - status int32 - ) - time_begin := time.Now() - - cont_type := make([]string, 1, 1) - cont_type[0] = "application/json" - header := w.Header() - header["Content-Type"] = cont_type - w.Header().Add("Access-Control-Allow-Origin", "*") - - m := parseHttpKv(r) - b, _ := ioutil.ReadAll(r.Body) - - req_log = fmt.Sprintf("handle %v %v %v from %v, len(m)=%v, m=%+v", - r.Method, r.URL.Path, r.URL.RawQuery, r.RemoteAddr, len(m), 
m) - - api := r.URL.Path - - var showHandler map[string]handlerFunc - switch r.Method { - case "GET": - showHandler = getHandler - case "POST": // create - showHandler = postHandler - case "PUT": // update - showHandler = putHandler - case "DELETE": - showHandler = deleteHandler - default: - logex.Warningf(`{"error":1, "message":"unsupport method %v"}`, r.Method) - } - - handler, ok := showHandler[api] - - if !ok { - key_list := make([]string, 0, len(showHandler)) - for key := range showHandler { - key_list = append(key_list, key) - } - status = 2 - fmt.Fprintf(w, `{"success":"%v", "message":"wrong api", "method":"%s", "api":"%s", "api_list":"%v"}`, - status, r.Method, api, key_list) - - logex.Noticef(`%v, time=%v, status=%v`, - req_log, time.Now().Sub(time_begin).Nanoseconds()/1000000, status) - return - } - - var s string - rst, handle_log, err := handler(api, m, b) - if err == nil { - status = 0 - s = fmt.Sprintf(`{"success":"%v", "message":"query ok", "data":%s}`, status, rst) - } else { - status = 255 - s = fmt.Sprintf(`{"success":"%v", "message":%v, "data":%s}`, - status, quote(err.Error()), rst) - } - - if isJsonDict(s) { - fmt.Fprintln(w, s) - } else { - logex.Fatalf("invalid json: %v", s) - } - - if err == nil { - logex.Noticef(`%v, time=%v, status=%v, handle_log=%v`, - req_log, time.Now().Sub(time_begin).Nanoseconds()/1000000, - status, quote(handle_log)) - } else { - logex.Noticef(`%v, time=%v, status=%v, err=%v, handle_log=%v`, - req_log, time.Now().Sub(time_begin).Nanoseconds()/1000000, - status, quote(err.Error()), quote(handle_log)) - } -} - -func parseHttpKv(r *http.Request) map[string]string { - r.ParseForm() - m := make(map[string]string) - for k, v := range r.Form { - switch k { - case "user": // remove @baidu.com for user - m[k] = strings.Split(v[0], "@")[0] - default: - m[k] = v[0] - } - } - - // allow passing hostname for debug - if _, ok := m["hostname"]; !ok { - ip := r.RemoteAddr[:strings.Index(r.RemoteAddr, ":")] - m["hostname"], _ = 
getHostname(ip) - } - return m -} - -// restReq sends a restful request to requrl and returns response body. -func restReq(method, requrl string, timeout int, kv *map[string]string) (string, error) { - logex.Debug("####restReq####") - logex.Debug(*kv) - data := url.Values{} - if kv != nil { - for k, v := range *kv { - logex.Trace("req set:", k, v) - data.Set(k, v) - } - } - if method == "GET" || method == "DELETE" { - requrl = requrl + "?" + data.Encode() - data = url.Values{} - } - - logex.Notice(method, requrl) - req, err := http.NewRequest(method, requrl, bytes.NewBufferString(data.Encode())) - if err != nil { - logex.Warning("NewRequest failed:", err) - return "", err - } - if method == "POST" || method == "PUT" { - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) - } - - client := &http.Client{} - client.Timeout = time.Duration(timeout) * time.Second - resp, err := client.Do(req) - if err != nil { - logex.Warning("Do failed:", err) - return "", err - } - if resp.StatusCode < 200 || resp.StatusCode > 299 { - logex.Warning("resp status: " + resp.Status) - return "", errors.New("resp status: " + resp.Status) - } - - body, err := ioutil.ReadAll(resp.Body) - return string(body), err -} diff --git a/core/cube/cube-agent/src/agent/http_get.go b/core/cube/cube-agent/src/agent/http_get.go deleted file mode 100755 index 86e394372f666d7ecce8b17ac9c66c22c824aad5..0000000000000000000000000000000000000000 --- a/core/cube/cube-agent/src/agent/http_get.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package agent - -import ( - "encoding/json" - "fmt" -) - -func initGetHandlers() { - getHandler = map[string]handlerFunc{ - "/agent/status": GetStatus, - } -} - -func GetStatus(subpath string, m map[string]string, b []byte) (string, string, error) { - b, err := json.Marshal(BUILTIN_STATUS) - if err != nil { - return quote(""), "", fmt.Errorf("json marshal failed, %v", err) - } - - return string(b), "", err -} diff --git a/core/cube/cube-agent/src/agent/http_post.go b/core/cube/cube-agent/src/agent/http_post.go deleted file mode 100755 index 921da6e7d6ef575ecbc4a22f159f54177546e18b..0000000000000000000000000000000000000000 --- a/core/cube/cube-agent/src/agent/http_post.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package agent - -import ( - "encoding/json" - "fmt" - "github.com/Badangel/logex" -) - -func initPostHandlers() { - postHandler = map[string]handlerFunc{ - "/agent/cmd": PostCmd, - } -} - -func PostCmd(subpath string, m map[string]string, b []byte) (string, string, error) { - var work Work - err := json.Unmarshal(b, &work) - if err != nil { - logex.Warningf("Unmarshal from %s error (+%v)", string(b), err) - return quote(""), "", fmt.Errorf("Work json unmarshal work failed, %v", err) - } - - if _, ok := CmdWorkFilter.Load(work.Token()); ok { - logex.Warningf("Another work with same token is doing. Token(%s)", work.Token()) - return quote(""), "", fmt.Errorf("Another work with same key is doing.", err) - } - - CmdWorkFilter.Store(work.Token(), true) - err = work.DoWork() - CmdWorkFilter.Delete(work.Token()) - if err != nil { - return quote(""), "", fmt.Errorf("Do work failed.", err) - } - - return quote(""), "", err -} diff --git a/core/cube/cube-agent/src/agent/util.go b/core/cube/cube-agent/src/agent/util.go deleted file mode 100644 index 29d27682a3c2e1c46d7ca8cb71de53c2e95df71f..0000000000000000000000000000000000000000 --- a/core/cube/cube-agent/src/agent/util.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package agent - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io/ioutil" - "net" - "net/http" - "net/url" - "strconv" - "strings" - "time" - - "github.com/Badangel/logex" -) - -// restReq sends a restful request to requrl and returns response body. -func RestReq(method, requrl string, timeout int, kv *map[string]string) (string, error) { - data := url.Values{} - if kv != nil { - for k, v := range *kv { - //logex.Trace("req set:", k, v) - data.Set(k, v) - } - } - if method == "GET" || method == "DELETE" { - requrl = requrl + "?" + data.Encode() - data = url.Values{} - } - - //logex.Notice(method, requrl) - req, err := http.NewRequest(method, requrl, bytes.NewBufferString(data.Encode())) - if err != nil { - logex.Warning("NewRequest failed:", err) - return "", err - } - if method == "POST" || method == "PUT" { - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) - } - - client := &http.Client{} - client.Timeout = time.Duration(timeout) * time.Second - resp, err := client.Do(req) - if err != nil { - logex.Warning("Do failed:", err) - return "", err - } - if resp.StatusCode < 200 || resp.StatusCode > 299 { - logex.Warning("resp status: " + resp.Status) - return "", errors.New("resp status: " + resp.Status) - } - - body, err := ioutil.ReadAll(resp.Body) - return string(body), err -} - -// restReq sends a restful request to requrl and returns response body as json. 
-func JsonReq(method, requrl string, timeout int, kv *map[string]string, - out interface{}) error { - s, err := RestReq(method, requrl, timeout, kv) - logex.Debugf("json request method:[%v], requrl:[%s], timeout:[%v], map[%v], out_str:[%s]", method, requrl, timeout, kv, s) - if err != nil { - return err - } - return json.Unmarshal([]byte(s), out) -} - -func GetHdfsMeta(src string) (master, ugi, path string, err error) { - //src = "hdfs://root:rootpasst@st1-inf-platform0.st01.baidu.com:54310/user/mis_user/news_dnn_ctr_cube_1/1501836820/news_dnn_ctr_cube_1_part54.tar" - //src = "hdfs://st1-inf-platform0.st01.baidu.com:54310/user/mis_user/news_dnn_ctr_cube_1/1501836820/news_dnn_ctr_cube_1_part54.tar" - - ugiBegin := strings.Index(src, "//") - ugiPos := strings.LastIndex(src, "@") - if ugiPos != -1 && ugiBegin != -1 { - ugi = src[ugiBegin+2 : ugiPos] - } - src1 := strings.Replace(strings.Replace(src, "hdfs://", "", 1), ugi, "", 1) - if ugi != "" { - src1 = src1[1:] - } - pos := strings.Index(src1, "/") - if pos != -1 { - master = src1[0:pos] - path = src1[pos:] - } else { - logex.Warningf("failed to get the master or path for (%s)", src) - err = errors.New("invalid master or path found") - } - logex.Debugf("parse the (%s) succ, master is %s, ugi is (%s), path is %s", src, master, ugi, path) - return -} - -func getHostIp() (string, error) { - if addrs, err := net.InterfaceAddrs(); err == nil { - for _, addr := range addrs { - ips := addr.String() - logex.Debugf("get host ip: %v", ips) - if strings.HasPrefix(ips, "127") { - continue - } else { - list := strings.Split(ips, "/") - if len(list) != 2 { - continue - } - return list[0], nil - } - } - } - return "unkown ip", errors.New("get host ip failed") -} - -func getHostname(ip string) (hostname string, err error) { - if hostnames, err := net.LookupAddr(ip); err != nil { - hostname = ip - //logex.Warningf("cannot find the hostname of ip (%s), error (%v)", ip, err) - } else { - if len(hostnames) > 0 { - hostname = 
hostnames[0] - } else { - hostname = ip - } - } - - return hostname, err -} - -func GetLocalHostname() (hostname string, err error) { - if ip, err := getHostIp(); err == nil { - return getHostname(ip) - } else { - return "unkown ip", err - } -} - -func GetLocalHostnameCmd() (hostname string, err error) { - cmd := "hostname" - stdout, _, err := RetryCmd(cmd, RETRY_TIMES) - if stdout != "" && err == nil { - hostname := strings.TrimSpace(stdout) - index := strings.LastIndex(hostname, ".baidu.com") - if index > 0 { - return hostname[:strings.LastIndex(hostname, ".baidu.com")], nil - } else { - return hostname, nil - } - } else { - logex.Debugf("using hostname cmd failed. err:%v", err) - return GetLocalHostname() - } -} - -// quote quotes string for json output. eg: s="123", quote(s)=`"123"` -func quote(s string) string { - return fmt.Sprintf("%q", s) -} - -// quoteb quotes byte array for json output. -func quoteb(b []byte) string { - return quote(string(b)) -} - -// quotea quotes string array for json output -func quotea(a []string) string { - b, _ := json.Marshal(a) - return string(b) -} - -func isJsonDict(s string) bool { - var js map[string]interface{} - return json.Unmarshal([]byte(s), &js) == nil -} diff --git a/core/cube/cube-agent/src/agent/work.go b/core/cube/cube-agent/src/agent/work.go deleted file mode 100644 index 8fdd90a52b877e2c7624a6ce48e13f7b0c2336c5..0000000000000000000000000000000000000000 --- a/core/cube/cube-agent/src/agent/work.go +++ /dev/null @@ -1,883 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package agent - -import ( - "crypto/md5" - "encoding/json" - "errors" - "fmt" - "github.com/Badangel/logex" - "github.com/Badangel/pipeline" - "os" - "os/exec" - "path" - "path/filepath" - "strconv" - "strings" - "sync" - "time" -) - -const ( - COMMAND_DOWNLOAD = "download" - COMMAND_RELOAD = "reload" - COMMAND_SWITCH = "enable" - COMMAND_ROLLBACK = "rollback" - COMMAND_CHECK = "check" - COMMAND_CLEAR = "clear" - COMMAND_POP = "pop" - - RETRY_TIMES = 3 - REQUEST_MASTER_TIMEOUT_SECOND = 60 - MAX_DOWN_CO = 7 - - RELOAD_RETRY_TIMES = 3 - RELOAD_RETRY_INTERVAL_SECOND = 10 - - DOWNLOAD_DONE_MARK_FILE = ".download_done" -) - -type VersionInfo struct { - Version string - Depend string - Source string -} - -type Work struct { - DictName string `json:"dict_name"` - ShardSeq int `json:"shard_seq"` - DeployPath string `json:"deploy_path"` - Command string `json:"command"` - Version string `json:"version"` - Depend string `json:"depend"` - Source string `json:"source"` - Mode string `json:"mode"` - DictMode string `json:"dict_mode"` - Port string `json:"port"` - bRollback bool `json:"b_rollback"` - RollbackInfo []VersionInfo `json:"rollback_info"` - Status string `json:"status"` - FinishStatus string `json:"finish_status"` - Service string `json:"service,omitempty"` - VersionSign string `json:"version_sign,omitempty"` - MasterAddress string `json:"master_address,omitempty"` - ActiveVersionList string `json:"active_version_list,omitempty"` -} - -func (work *Work) Token() string { - return work.DictName + strconv.Itoa(work.ShardSeq) + work.Service -} - -func 
(work *Work) Valid() bool { - if work.Command == "" || - work.Version == "" || - work.Depend == "" { - return false - } - return true -} - -func (work *Work) DoWork() error { - var err error - if !work.Valid() { - err = errors.New("Work is invalid") - return err - } - switch work.Command { - case COMMAND_DOWNLOAD: - err = work.Download() - case COMMAND_RELOAD: - err = work.Reload() - case COMMAND_SWITCH: - err = work.Enable() - case COMMAND_CHECK: - err = work.Check() - case COMMAND_CLEAR: - err = work.Clear() - case COMMAND_POP: - err = work.Pop() - default: - logex.Debugf("Invalid command %s received", work.Command) - err = errors.New("Invalid command.") - } - return err -} - -func GetDownloadDirs(dictName, service, version, depend, deployPath string, shardSeq, - split int) ([]string, error) { - dirs := make([]string, 0, split) - if deployPath == "" { - return dirs, errors.New("Invalid Deploy path") - } - parentDir := getParentDir(version, depend) - if split < 2 { - disk := path.Join(deployPath, "cube_data") - if service == "" { - dirs = append(dirs, path.Join(disk, strconv.Itoa(shardSeq), parentDir)) - } else { - dirs = append(dirs, path.Join(disk, strconv.Itoa(shardSeq), parentDir+"-"+dictName)) - } - } else { - for i := 0; i < split; i++ { - disk := path.Join(deployPath, "cube_data") - if service == "" { - dirs = append(dirs, path.Join(disk, strconv.Itoa(shardSeq), strconv.Itoa(i), parentDir)) - } else { - dirs = append(dirs, path.Join(disk, strconv.Itoa(shardSeq), - parentDir+"-"+dictName)) - } - } - } - - return dirs, nil -} - -func GetDataLinkDirs(dictName, service, version, depend, deployPath string, shardSeq, - split int) []string { - dirs := make([]string, 0, split) - parentDir := getParentDir(version, depend) - if split < 2 { - disk := path.Join(deployPath, "data") - if service == "" { - dirs = append(dirs, path.Join(disk, parentDir)) - } - } else { - for i := 0; i < split; i++ { - disk := path.Join(deployPath, "data") - if service == "" { - dirs = 
append(dirs, path.Join(disk, strconv.Itoa(i), parentDir)) - } - } - } - return dirs -} - -func (work *Work) Download() (err error) { - err = DoDownload(work.DictName, work.Service, work.Version, work.Depend, work.Mode, work.Source, - work.DeployPath, work.ShardSeq) - - if err != nil { - logex.Warningf("download error, failed to download %s, dir is %s, error is (+%v)", work.Source, work.DeployPath, err) - return - } - - if work.Service == "" { - err = UnTar(work.DictName, work.Service, work.Version, work.Depend, work.Source, - work.DeployPath, work.ShardSeq) - - if err == nil { - dataPath := path.Join(work.DeployPath, "data") - - // remove all old links - if work.Mode == "base" || len(work.RollbackInfo) != 0 { - cmd := fmt.Sprintf("ls -l %s | grep -E 'data.|index.' | awk '{print $9}'", dataPath) - stdout, _, err := RetryCmd(cmd, RETRY_TIMES) - if err == nil && stdout != "" { - fileNameLi := strings.Split(strings.TrimSpace(stdout), "\n") - for _, file := range fileNameLi { - err = os.Remove(path.Join(dataPath, file)) - logex.Debugf("os.Remove(%s) error (%+v) ", path.Join(dataPath, file), err) - } - } - } - - // create symbolic link to the version rollbacked - err = CreateSymlink(work.DictName, work.Service, work.Version, work.Depend, dataPath, - work.DeployPath, work.ShardSeq, len(strings.Split(work.Source, ";"))) - } else { - logex.Warningf("download error, failed to untar for %s, dir is %s, error is (+%v)", work.Source, work.DeployPath, err) - } - } - - if err == nil { - // clear history data - work.clearData() - work.clearLink() - } else { - logex.Warningf("create symlink failed, error is (+%v)", err) - } - - return -} - -func (work *Work) clearData() (err error) { - split := len(strings.Split(work.Source, ";")) - downloadDirs, err := GetDownloadDirs(work.DictName, work.Service, work.Version, work.Depend, - work.DeployPath, work.ShardSeq, split) - if err != nil { - logex.Warningf("clearData failed, error is (+%v)", err) - return - } - for _, downloadDir := range 
downloadDirs { - parentDir, _ := filepath.Split(downloadDir) - - cmd := fmt.Sprintf("ls -l %s | grep -v %s | awk '{print $9}'", parentDir, work.Depend) - - stdout, _, err := RetryCmd(cmd, RETRY_TIMES) - if err != nil || stdout == "" || work.Service != "" { - cmd = fmt.Sprintf("find %s -type d -ctime +1 -print | xargs -i rm -rf {}", parentDir) - _, _, err = RetryCmd(cmd, RETRY_TIMES) - } else { - rmList := strings.Split(strings.TrimSpace(stdout), "\n") - for i := 0; i < len(rmList); i++ { - if rmList[i] == "" { - continue - } - cmd = fmt.Sprintf("rm -rf %s/%s*", parentDir, rmList[i]) - _, _, err = RetryCmd(cmd, RETRY_TIMES) - } - } - } - - return -} - -func (work *Work) clearPatchData() (err error) { - if work.Service != "" { - return - } - split := len(strings.Split(work.Source, ";")) - downloadDirs, err := GetDownloadDirs(work.DictName, work.Service, work.Version, work.Depend, - work.DeployPath, work.ShardSeq, split) - if err != nil { - logex.Warningf("clearPatchData failed, error is (+%v)", err) - return - } - for _, downloadDir := range downloadDirs { - parentDir, _ := filepath.Split(downloadDir) - cmd := fmt.Sprintf("ls -l %s | grep %s_ | awk '{print $9}'", parentDir, work.Depend) - stdout, _, err := RetryCmd(cmd, RETRY_TIMES) - if err == nil && stdout != "" { - rmList := strings.Split(strings.TrimSpace(stdout), "\n") - for i := 0; i < len(rmList); i++ { - if rmList[i] == "" { - continue - } - cmd = fmt.Sprintf("rm -rf %s/%s*", parentDir, rmList[i]) - _, _, err = RetryCmd(cmd, RETRY_TIMES) - } - } - } - - return -} - -func (work *Work) clearLink() (err error) { - if work.Service != "" { - return - } - split := len(strings.Split(work.Source, ";")) - dataLinkDirs := GetDataLinkDirs(work.DictName, work.Service, work.Version, work.Depend, - work.DeployPath, work.ShardSeq, split) - for _, linkDir := range dataLinkDirs { - parentDir, _ := filepath.Split(linkDir) - cmd := fmt.Sprintf("ls -l %s | grep -v %s | awk '{print $9}'", parentDir, work.Depend) - - stdout, _, 
err := RetryCmd(cmd, RETRY_TIMES) - if err != nil || stdout == "" { - cmd = fmt.Sprintf("find %s -type d -ctime +1 -print | xargs -i rm -rf {}", parentDir) - _, _, err = RetryCmd(cmd, RETRY_TIMES) - } else { - rmList := strings.Split(strings.TrimSpace(stdout), "\n") - for i := 0; i < len(rmList); i++ { - if rmList[i] == "" { - continue - } - cmd = fmt.Sprintf("rm -rf %s/%s*", parentDir, rmList[i]) - _, _, err = RetryCmd(cmd, RETRY_TIMES) - } - } - } - - return -} - -func (work *Work) clearPatchLink() (err error) { - if work.Service != "" { - return - } - split := len(strings.Split(work.Source, ";")) - dataLinkDirs := GetDataLinkDirs(work.DictName, work.Service, work.Version, work.Depend, - work.DeployPath, work.ShardSeq, split) - for _, linkDir := range dataLinkDirs { - parentDir, _ := filepath.Split(linkDir) - cmd := fmt.Sprintf("ls -l %s | grep %s_ | awk '{print $9}'", parentDir, work.Depend) - - stdout, _, err := RetryCmd(cmd, RETRY_TIMES) - if err == nil && stdout != "" { - rmList := strings.Split(strings.TrimSpace(stdout), "\n") - for i := 0; i < len(rmList); i++ { - if rmList[i] == "" { - continue - } - cmd = fmt.Sprintf("rm -rf %s/%s*", parentDir, rmList[i]) - _, _, err = RetryCmd(cmd, RETRY_TIMES) - } - } - } - - return -} - -func UnTar(dictName, service, version, depend, source, deployPath string, shardSeq int) (err error) { - sources := strings.Split(source, ";") - downloadDirs, err := GetDownloadDirs(dictName, service, version, depend, deployPath, shardSeq, - len(sources)) - if err != nil { - logex.Warningf("UnTar failed, error is (+%v)", err) - return - } - for i := 0; i < len(sources); i++ { - fileName := GetFileName(sources[i]) - untarCmd := fmt.Sprintf("tar xvf %s -C %s", path.Join(downloadDirs[i], fileName), downloadDirs[i]) - _, _, err = RetryCmd(untarCmd, RETRY_TIMES) - } - - return -} - -func CreateSymlink(dictName, service, version, depend, dataPath, deployPath string, shardSeq, - split int) (err error) { - downloadDirs, err := 
GetDownloadDirs(dictName, service, version, depend, deployPath, shardSeq, split) - if err != nil { - logex.Warningf("CreateSymlink failed, error is (+%v)", err) - } - for i, downloadDir := range downloadDirs { - cmd := fmt.Sprintf("ls -l %s | grep -E 'data.|index.' | awk '{print $NF}'", downloadDir) - stdout, _, err := RetryCmd(cmd, RETRY_TIMES) - - if err == nil && stdout != "" { - fileNameLi := strings.Split(strings.TrimSpace(stdout), "\n") - versionDir := getParentDir(version, depend) - versionFile := path.Join(dataPath, "VERSION") - dataSubPath := "" - if split > 1 { - dataSubPath = path.Join(dataPath, strconv.Itoa(i), versionDir) - } else { - dataSubPath = path.Join(dataPath, versionDir) - } - if err = os.MkdirAll(dataSubPath, 0755); err != nil { - // return err - logex.Warningf("os.Mkdir %s failed, err:[%v]", dataSubPath, err) - } - if dataSubPath != "" { - cmd = fmt.Sprintf("find %s/.. -type d -ctime +5 -print | xargs -i rm -rf {}", dataSubPath) - _, _, err = RetryCmd(cmd, RETRY_TIMES) - } - for _, file := range fileNameLi { - dataLink := "" - tempDataPath := "" - if split > 1 { - dataLink = path.Join(dataPath, strconv.Itoa(i), file) - tempDataPath = path.Join(dataPath, strconv.Itoa(i)) - } else { - dataLink = path.Join(dataPath, file) - tempDataPath = dataPath - } - cmd = fmt.Sprintf("rm -rf %s", dataLink) - _, stderr, _ := RetryCmd(cmd, RETRY_TIMES) - logex.Noticef("rm -rf %s, err:[%s]", dataLink, stderr) - - // create new symlink - err = os.Symlink(path.Join(downloadDir, file), dataLink) - logex.Noticef("os.Symlink %s %s return (%+v)", path.Join(downloadDir, file), dataLink, err) - fmt.Println("os.Symlink: ", path.Join(downloadDir, file), dataLink, err) - cmd = fmt.Sprintf("cp -d %s/index.* %s/", tempDataPath, dataSubPath) - _, stderr, _ = RetryCmd(cmd, RETRY_TIMES) - logex.Noticef("cp -d index Symlink to version dir %s, err:[%s]", dataSubPath, stderr) - cmd = fmt.Sprintf("cp -d %s/data.* %s/", tempDataPath, dataSubPath) - _, stderr, _ = RetryCmd(cmd, 
// ---- tail of CreateSymlink: the function head lies above this chunk ----
RETRY_TIMES)
			logex.Noticef("cp -d data Symlink to version dir %s, err:[%s]", dataSubPath, stderr)
		}
		// Persist which version this download dir now serves.
		cmd = fmt.Sprintf("echo %s > %s", versionDir, versionFile)
		if _, _, err = RetryCmd(cmd, RETRY_TIMES); err != nil {
			return err
		}
		}
	}

	return
}

// CheckToReload asks the local cube ControlService for its current/background
// versions and reports whether a reload is actually needed.
func (work *Work) CheckToReload() bool {
	statusCmd := fmt.Sprintf("curl -s -d '{\"cmd\":\"status\"}' http://127.0.0.1:%s/ControlService/cmd", work.Port)
	stdout, _, _ := RetryCmd(statusCmd, RETRY_TIMES)
	var resp CubeResp
	// Unmarshal error deliberately ignored: an empty/garbled reply leaves both
	// version fields empty, which is treated as "needs reload" below.
	json.Unmarshal([]byte(stdout), &resp)
	version := getParentDir(work.Version, work.Depend)

	if resp.CurVersion == "" && resp.BgVersion == "" {
		logex.Noticef("cube version empty")
		return true
	}
	if resp.CurVersion == version || resp.BgVersion == version {
		logex.Noticef("cube version has matched. version: %s", version)
		return false
	}
	return true
}

// Reload triggers a background load (base or patch, depending on work.Mode) on
// the local cube instance and records the outcome in the state file.
func (work *Work) Reload() (err error) {
	if work.Port == "" {
		err = errors.New("Reload with invalid port.")
		return
	}
	if !work.CheckToReload() {
		work.writeStatus("finish_reload", "succ")
		return
	}
	work.writeStatus("prepare_reload", "")

	var stdout string
	versionPath := getParentDir(work.Version, work.Depend)
	bgLoadCmd := "bg_load_base"
	if work.Mode == "delta" {
		bgLoadCmd = "bg_load_patch"
	}
	if work.ActiveVersionList == "" {
		work.ActiveVersionList = "[]"
	}
	// Retry until cube answers HTTP 200 or RELOAD_RETRY_TIMES is exhausted.
	for i := 0; i < RELOAD_RETRY_TIMES; i++ {
		reloadCmd := fmt.Sprintf("curl -o /dev/null -s -w %%{http_code} -d '{\"cmd\":\"%s\",\"version_path\":\"/%s\"}' http://127.0.0.1:%s/ControlService/cmd", bgLoadCmd, versionPath, work.Port)
		fmt.Println("reload: ", reloadCmd)
		stdout, _, _ = RetryCmd(reloadCmd, 1)
		fmt.Println("reload stdout: ", stdout)
		if strings.TrimSpace(stdout) == "200" {
			logex.Debugf("bg_load_base return succ")
			break
		} else {
			logex.Warning("bg_load_base return failed")
			time.Sleep(RELOAD_RETRY_INTERVAL_SECOND * time.Second)
		}
	}

	if strings.TrimSpace(stdout) == "200" {
		work.writeStatus("finish_reload", "succ")
	} else {
		work.writeStatus("finish_reload", "failed")
		err = errors.New("reload failed.")
	}

	return
}

// Clear asks the local cube NodeControlService to drop the table for this
// dict, then records succ/failed in the state file.
func (work *Work) Clear() (err error) {
	work.Service = ""

	var stdout string
	var clearCmd string
	for i := 0; i < RETRY_TIMES; i++ {
		clearCmd = fmt.Sprintf("curl -o /dev/null -s -w %%{http_code} -d '{\"cmd\":\"clear\",\"table_name\":\"%s\"}' http://127.0.0.1:%s/NodeControlService/cmd", work.DictName, work.Port)
		fmt.Println("clear: ", clearCmd)
		stdout, _, _ = RetryCmd(clearCmd, 1)
		fmt.Println("clear stdout: ", stdout)
		if strings.TrimSpace(stdout) == "200" {
			logex.Debugf("clear return succ")
			break
		} else {
			logex.Warning("clear return failed")
			time.Sleep(RELOAD_RETRY_INTERVAL_SECOND * time.Second)
		}
	}

	if strings.TrimSpace(stdout) == "200" {
		err = work.writeStatus("succ", "")
	} else {
		err = work.writeStatus("failed", "")
	}

	return
}

// Check verifies the on-disk data links against work.VersionSign and, on
// mismatch, clears patch data and re-runs the downloader.
func (work *Work) Check() (err error) {
	if work.Service != "" || work.VersionSign == "" {
		return
	}
	var dataLinkDirs []string
	split := len(strings.Split(work.Source, ";"))
	dataLinkDirs = GetDataLinkDirs(work.DictName, work.Service, work.Version, work.Depend,
		work.DeployPath, work.ShardSeq, split)

	if _, t_err := os.Stat(work.DeployPath); os.IsNotExist(t_err) {
		logex.Noticef("check DeployPath[%s] not exists.", work.DeployPath)
		return
	}

	check_succ := true
	for _, linkDir := range dataLinkDirs {
		parentDir, _ := filepath.Split(linkDir)

		// List version directories (names containing Depend, excluding data/index files).
		cmd := fmt.Sprintf("ls -l %s | grep %s | awk '{print $9}' | grep -v data | grep -v index", parentDir, work.Depend)

		stdout, _, err := RetryCmd(cmd, RETRY_TIMES)
		if err != nil || stdout == "" {
			check_succ = false
			break
		} else {
			versionList := strings.Split(strings.TrimSpace(stdout), "\n")
			logex.Noticef("calc ver_sign for [%v]", versionList)

			var version_sign string
			var version string
			for i := 0; i < len(versionList); i++ {
				split_index :=
strings.Index(versionList[i], "_")
				// Strip the "depend_" prefix when present; otherwise use the raw name.
				if split_index > 0 && split_index < len(versionList[i]) {
					version = versionList[i][split_index+1:]
				} else {
					version = versionList[i]
				}
				if version_sign == "" {
					version_sign = fmt.Sprintf("%x", md5.Sum([]byte(version)))
				} else {
					// NOTE(review): after the first entry this hashes only the
					// running sign and ignores `version` — confirm this matches
					// the master's signature algorithm.
					version_sign = fmt.Sprintf("%x", md5.Sum([]byte(version_sign)))
				}
			}

			if version_sign != work.VersionSign {
				logex.Warningf("version_sign check failed. real[%v] expect[%v]", version_sign, work.VersionSign)
				check_succ = false
				break
			}
		}
	}

	if !check_succ {
		// Signature mismatch: wipe patch data/links and re-download this shard.
		work.clearPatchData()
		work.clearPatchLink()
		master_host, master_port, _ := GetMaster(work.MasterAddress)
		cmd := fmt.Sprintf("cd %s && export STRATEGY_DIR=%s && ./downloader -h %s -p %s -d %s -s %d",
			work.DeployPath, work.DeployPath, master_host, master_port, work.DictName, work.ShardSeq)
		_, _, err = RetryCmd(cmd, RETRY_TIMES)
	}

	return
}

// Enable switches the local cube to the prepared version (reload_model for a
// service dict, enable otherwise), records the result, and for the non-service
// case asks cube to bg_unload the previous version.
func (work *Work) Enable() (err error) {
	if work.Port == "" {
		err = errors.New("Enable with invalid port")
		return
	}
	var stdout string
	var cmd string
	versionPath := getParentDir(work.Version, work.Depend)
	for i := 0; i < RELOAD_RETRY_TIMES; i++ {
		if work.Service != "" {
			cmd = fmt.Sprintf("curl -o /dev/null -s -w %%{http_code} -d '{\"cmd\":\"reload_model\",\"version\":\"%s-%s\",\"dict_name\":\"%s\"}' http://127.0.0.1:%s/ControlService/cmd",
				versionPath, work.DictName, work.DictName, work.Port)
		} else {
			cmd = fmt.Sprintf("curl -o /dev/null -s -w %%{http_code} -d '{\"cmd\":\"enable\",\"version\":\"%s\"}' http://127.0.0.1:%s/ControlService/cmd", versionPath, work.Port)
		}
		stdout, _, _ = RetryCmd(cmd, 1)

		if strings.TrimSpace(stdout) == "200" {
			logex.Debugf("enable return succ for %s, work dir is %s", work.Source, work.DeployPath)
			break
		} else {
			logex.Warningf("enable return failed for %s, work dir is %s, error is (%+v)", work.Source, work.DeployPath, err)
			time.Sleep(RELOAD_RETRY_INTERVAL_SECOND * time.Second)
		}
	}

	if strings.TrimSpace(stdout) == "200" {
		err = work.writeStatus("succ", "")
	} else {
		err = work.writeStatus("failed", "")
	}

	if work.Service == "" {
		// Best-effort: unload the now-inactive background version.
		cmd = fmt.Sprintf("curl -o /dev/null -s -w %%{http_code} -d '{\"cmd\":\"bg_unload\"}' http://127.0.0.1:%s/ControlService/cmd", work.Port)
		stdout, _, _ = RetryCmd(cmd, RETRY_TIMES)
		if strings.TrimSpace(stdout) == "200" {
			logex.Debugf("unload return succ")
		} else {
			logex.Warning("unload return failed")
		}
	}

	RemoveStateFile(work.DictName, work.ShardSeq, work.Service)

	return
}

// Pop asks the local cube to pop old versions not in ActiveVersionList, then
// records the result and removes the state file.
func (work *Work) Pop() (err error) {
	var stdout string
	var cmd string
	if work.ActiveVersionList == "" {
		work.ActiveVersionList = "[]"
	}
	for i := 0; i < RELOAD_RETRY_TIMES; i++ {
		cmd = fmt.Sprintf("curl -o /dev/null -s -w %%{http_code} -d '{\"cmd\":\"pop\",\"table_name\":\"%s\",\"active_versions\":%v}' http://127.0.0.1:%s/NodeControlService/cmd", work.DictName, work.ActiveVersionList, work.Port)
		fmt.Println("pop: ", cmd)
		stdout, _, _ = RetryCmd(cmd, 1)
		fmt.Println("pop stdout: ", stdout)
		if strings.TrimSpace(stdout) == "200" {
			logex.Debugf("pop return succ")
			break
		} else {
			logex.Warning("pop return failed")
			time.Sleep(RELOAD_RETRY_INTERVAL_SECOND * time.Second)
		}
	}

	if strings.TrimSpace(stdout) == "200" {
		err = work.writeStatus("succ", "")
	} else {
		err = work.writeStatus("failed", "")
	}

	RemoveStateFile(work.DictName, work.ShardSeq, work.Service)
	return
}

// writeStateFile writes `state` into Dir/state/.state_<dict>_<shard>[_<service>].
func writeStateFile(dictName string, shardSeq int, service, state string) {
	stateFile := fmt.Sprintf(".state_%s_%d", dictName, shardSeq)
	if service != "" {
		stateFile = stateFile + "_" + service
	}

	cmd := fmt.Sprintf("echo '%s' > %s/state/%s", state, Dir, stateFile)
	if _, _, err := RetryCmd(cmd, RETRY_TIMES); err != nil {
		logex.Warningf("%s error (%+v)", cmd, err)
	}
}

// RemoveStateFile deletes the per-dict/shard state file written by writeStateFile.
func RemoveStateFile(dictName string, shardSeq int, service string) {
	stateFile := fmt.Sprintf(".state_%s_%d", dictName,
shardSeq)
	if service != "" {
		stateFile = stateFile + "_" + service
	}

	cmd := fmt.Sprintf("rm -f %s/state/%s", Dir, stateFile)
	if _, _, err := RetryCmd(cmd, RETRY_TIMES); err != nil {
		logex.Warningf("%s error (%+v)", cmd, err)
	}
}

// writeStatus updates the Work's status fields and persists the whole Work as
// JSON into the state file. Marshal error is ignored (Work is a plain struct).
func (work *Work) writeStatus(status string, finishStatus string) (err error) {
	work.Status = status
	work.FinishStatus = finishStatus
	state, _ := json.Marshal(work)
	writeStateFile(work.DictName, work.ShardSeq, work.Service, string(state))
	return
}

// DoDownloadIndividual downloads one source into downloadDir with retries,
// verifies the md5 sidecar for non-service dicts, touches the done-marker on
// success, and reports the final error on ch. Runs as a goroutine; wg must
// have been Add(1)'d by the caller.
func DoDownloadIndividual(source, downloadDir string, isService bool, timeOut int, ch chan error, wg *sync.WaitGroup) {
	err := errors.New("DoDownloadIndividual start")
	for i := 0; i < RETRY_TIMES; i++ {
		err = FtpDownload(source, downloadDir, timeOut)
		if err == nil {
			logex.Debugf("download %s to %s succ", source, downloadDir)
			if !isService {
				// Non-service dicts ship an md5 sidecar to verify against.
				err = FtpDownload(source+".md5", downloadDir, timeOut)
			}
		} else {
			logex.Warningf("download error , source %s, downloadDir %s, err (%+v)", source, downloadDir, err)
			continue
		}

		if err == nil && isService {
			// touch download_succ file
			cmd := fmt.Sprintf("touch %s", path.Join(downloadDir, DOWNLOAD_DONE_MARK_FILE))
			RetryCmd(cmd, RETRY_TIMES)
			break
		}

		// download md5 file succ, md5check
		if err == nil {
			// md5sum -c
			fileName := GetFileName(source)
			err = checkMd5(path.Join(downloadDir, fileName), path.Join(downloadDir, fileName+".md5"))
			logex.Warningf("md5sum check %s %s return (%+v)", path.Join(downloadDir, fileName), path.Join(downloadDir, fileName+".md5"), err)
			if err == nil {
				// touch download_succ file
				cmd := fmt.Sprintf("touch %s", path.Join(downloadDir, DOWNLOAD_DONE_MARK_FILE))
				RetryCmd(cmd, RETRY_TIMES)
				logex.Debugf("md5sum ok, source is %s, dir is %s", source, downloadDir)
				break
			} else {
				logex.Warningf("md5sum error, source is %s, dir is %s", source, downloadDir)
				continue
			}
		} else {
			logex.Warningf("download %s return (%+v)", source+".md5", err)
			continue
		}
	}

	ch <- err
	wg.Done()
}

// checkSources splits the ';'-separated source list and rejects entries that
// are empty or not ftp:// / http:// URLs.
func checkSources(source string) ([]string, error) {
	sources := strings.Split(source, ";")
	for i := 0; i < len(sources); i++ {
		if sources[i] == "" || (!strings.HasPrefix(sources[i], "ftp://") && !strings.HasPrefix(sources[i], "http://")) {
			return sources, errors.New("Invalid sources")
		}
	}
	return sources, nil
}

// DoDownload downloads every split of a dict version concurrently (at most
// MAX_DOWN_CO goroutines at a time) and stamps the VERSION files on success.
// If all done-markers already exist, it only refreshes the VERSION files.
func DoDownload(dictName, service, version, depend, mode, source, deployPath string,
	shardSeq int) (err error) {
	sources, err := checkSources(source)
	if err != nil {
		logex.Warningf("checkSources %s return (%+v)", source, err)
		return
	}
	downloadDirs, err := GetDownloadDirs(dictName, service, version, depend, deployPath, shardSeq,
		len(sources))
	if err != nil {
		logex.Warningf("GetDownloadDirs %s return (%+v)", source, err)
		return
	}
	version_suffix := ""
	if service != "" {
		version_suffix = version_suffix + "-" + dictName
	}
	if !checkToDownload(downloadDirs) {
		// Everything already downloaded: just refresh the VERSION stamps.
		cmd := fmt.Sprintf("cd %s/cube_data && echo %s > VERSION && cp VERSION VERSION-%s",
			deployPath, getParentDir(version, depend)+version_suffix, dictName)
		_, _, err = RetryCmd(cmd, RETRY_TIMES)
		logex.Debugf("echo VERSION cmd:[%s] err:[%v]", cmd, err)
		return
	}

	ch := make(chan error, len(sources))
	wg := sync.WaitGroup{}
	j := 0
	numCo := 0
	for ; j < len(sources); j++ {
		// Throttle: wait for the current batch once MAX_DOWN_CO are in flight.
		if numCo >= MAX_DOWN_CO {
			wg.Wait()
			logex.Noticef("DoDownload co down.")
			numCo = 0
		}
		numCo += 1
		wg.Add(1)
		time.Sleep(2000 * time.Millisecond)
		timeOut := 900
		if mode == "base" {
			timeOut = 3600
		}
		go DoDownloadIndividual(sources[j], downloadDirs[j], (service != ""), timeOut, ch, &wg)
	}
	wg.Wait()
	close(ch)
	// First failed split aborts the whole download.
	for err = range ch {
		if err != nil {
			return
		}
	}
	cmd := fmt.Sprintf("cd %s/cube_data && echo %s > VERSION && cp VERSION VERSION-%s",
		deployPath, getParentDir(version, depend)+version_suffix, dictName)
	_, _, err = RetryCmd(cmd, RETRY_TIMES)
	logex.Debugf("echo VERSION cmd:[%s] 
err:[%v]", cmd, err)
	return
}

// FtpDownload fetches source into dest with wget (rate-limited); retries with
// a sleep between attempts. timeOut is currently unused by the wget variant.
func FtpDownload(source string, dest string, timeOut int) (err error) {
	cmd := fmt.Sprintf("wget --quiet --level=100 -P %s %s --limit-rate=10240k", dest, source)
	fmt.Println(cmd)

	_, _, err = RetryCmdWithSleep(cmd, RETRY_TIMES)
	return
}

// checkToDownload reports whether any download dir is missing its done-marker
// file, i.e. whether a (re-)download is required.
func checkToDownload(downloadDirs []string) bool {
	for _, dir := range downloadDirs {
		if _, err := os.Stat(path.Join(dir, DOWNLOAD_DONE_MARK_FILE)); err != nil {
			logex.Noticef("check [%v] not exists.", dir)
			return true
		}
	}
	return false
}

// getDownloadDisk deterministically picks a disk for a dict shard (simple hash).
func getDownloadDisk(dictName string, shardSeq int) string {
	return disks[len(dictName)*shardSeq%len(disks)]
}

// getParentDir builds the version directory name: depend itself for a base
// version, otherwise "depend_version".
func getParentDir(version string, depend string) (dir string) {
	dir = depend
	if version != depend {
		dir = depend + "_" + version
	}
	return
}

// GetFileName returns the final path component of source.
func GetFileName(source string) (fileName string) {
	fileName = source[strings.LastIndex(source, "/")+1:]
	return
}

// checkMd5 compares the md5 of file against the checksum recorded in fileMd5.
func checkMd5(file string, fileMd5 string) (err error) {
	sumCmd := fmt.Sprintf("md5sum %s | awk '{print $1}'", file)
	out, _, _ := pipeline.Run(exec.Command("/bin/sh", "-c", sumCmd))
	real_md5 := out.String()

	catCmd := fmt.Sprintf("cat %s | awk '{print $1}'", fileMd5)
	out, _, _ = pipeline.Run(exec.Command("/bin/sh", "-c", catCmd))
	given_md5 := out.String()

	// Both values keep their trailing newline, so the comparison is consistent.
	if real_md5 != given_md5 {
		logex.Warningf("checkMd5 failed real_md5[%s] given_md5[%s]", real_md5, given_md5)
		err = errors.New("checkMd5 failed")
	}
	return
}

// RetryCmd runs cmd through /bin/sh up to retryTimes times, returning the
// stdout/stderr/err of the last attempt; it stops at the first success.
func RetryCmd(cmd string, retryTimes int) (stdoutStr string, stderrStr string, err error) {
	for attempt := 0; attempt < retryTimes; attempt++ {
		stdout, stderr, runErr := pipeline.Run(exec.Command("/bin/sh", "-c", cmd))
		stdoutStr, stderrStr, err = stdout.String(), stderr.String(), runErr

		logex.Debugf("cmd %s, stdout: %s, stderr: %s, err: (%+v)", cmd, stdoutStr, stderrStr, err)
		if err == nil {
			break
		}
	}
	return
}

// RetryCmdWithSleep is RetryCmd with a 10s pause after every failed attempt.
func RetryCmdWithSleep(cmd string, retryTimes int) (stdoutStr string, stderrStr string, err error) {
	for attempt := 0; attempt < retryTimes; attempt++ {
		stdout, stderr, runErr := pipeline.Run(exec.Command("/bin/sh", "-c", cmd))
		stdoutStr, stderrStr, err = stdout.String(), stderr.String(), runErr

		logex.Debugf("cmd %s, stdout: %s, stderr: %s, err: (%+v)", cmd, stdoutStr, stderrStr, err)
		if err == nil {
			break
		}
		time.Sleep(10000 * time.Millisecond)
	}
	return
}
type (
	// workType pairs a queued worker with the channel used to report whether
	// it was accepted into the queue.
	workType struct {
		poolWorker    PoolWorker
		resultChannel chan error
	}

	// WorkPool is a fixed-size worker pool with a bounded pending queue and
	// per-token deduplication of in-flight work.
	WorkPool struct {
		queueChannel    chan workType   // hand-off from posters to the queue router
		workChannel     chan PoolWorker // buffered queue drained by the workers
		queuedWorkNum   int32           // works accepted but not yet started
		activeWorkerNum int32           // works currently executing
		queueCapacity   int32           // upper bound for queuedWorkNum
		workFilter      sync.Map        // token -> true for queued/executing work
	}
)

// PoolWorker is a unit of work identified by a deduplication token.
type PoolWorker interface {
	Token() string
	DoWork()
}

// NewWorkPool starts workerNum worker goroutines draining a queue of at most
// queueCapacity pending works and returns the ready-to-use pool.
func NewWorkPool(workerNum int, queueCapacity int32) *WorkPool {
	workPool := WorkPool{
		queueChannel:    make(chan workType),
		workChannel:     make(chan PoolWorker, queueCapacity),
		queuedWorkNum:   0,
		activeWorkerNum: 0,
		queueCapacity:   queueCapacity,
	}

	for i := 0; i < workerNum; i++ {
		go workPool.startWorkRoutine()
	}

	go workPool.startQueueRoutine()

	return &workPool
}

// startWorkRoutine is the worker loop: it executes queued works one by one.
// The channel is never closed, so the loop runs for the process lifetime.
func (workPool *WorkPool) startWorkRoutine() {
	for work := range workPool.workChannel {
		workPool.doWork(work)
	}
}

// startQueueRoutine admits posted works into the buffered work channel,
// rejecting them when the pending queue is full.
func (workPool *WorkPool) startQueueRoutine() {
	for queueItem := range workPool.queueChannel {
		// atomic.LoadInt32 replaces the original atomic.AddInt32(&n, 0)
		// idiom; >= guards against any transient overshoot of the capacity.
		if atomic.LoadInt32(&workPool.queuedWorkNum) >= workPool.queueCapacity {
			// Report this pool's actual capacity (the original printed the
			// package-level QueueCapacity, which can differ from the value
			// this pool was constructed with).
			queueItem.resultChannel <- fmt.Errorf("work pool fulled with %v pending works", workPool.queueCapacity)
			continue
		}

		atomic.AddInt32(&workPool.queuedWorkNum, 1)

		workPool.workChannel <- queueItem.poolWorker

		queueItem.resultChannel <- nil
	}
}

// doWork runs one item and maintains the counters and the token filter.
func (workPool *WorkPool) doWork(poolWorker PoolWorker) {
	defer atomic.AddInt32(&workPool.activeWorkerNum, -1)
	defer workPool.workFilter.Delete(poolWorker.Token())

	atomic.AddInt32(&workPool.queuedWorkNum, -1)
	atomic.AddInt32(&workPool.activeWorkerNum, 1)

	poolWorker.DoWork()
}

// PostWorkWithToken queues poolWorker unless another work with the same token
// is already queued or executing.
func (workPool *WorkPool) PostWorkWithToken(poolWorker PoolWorker) (err error) {
	// LoadOrStore closes the check-then-store race the original Load+Store
	// pair had between concurrent posters of the same token.
	if _, loaded := workPool.workFilter.LoadOrStore(poolWorker.Token(), true); loaded {
		return errors.New("another work with same key is doing.")
	}
	if err = workPool.PostWork(poolWorker); err != nil {
		// Release the token when queueing fails; otherwise the key could
		// never be posted again (doWork only deletes it when the work runs).
		workPool.workFilter.Delete(poolWorker.Token())
	}
	return err
}

// PostWork queues poolWorker and blocks until the queue router accepts or
// rejects it; a non-nil error means the pending queue is full.
func (workPool *WorkPool) PostWork(poolWorker PoolWorker) (err error) {
	work := workType{poolWorker, make(chan error)}

	defer close(work.resultChannel)

	workPool.queueChannel <- work

	err = <-work.resultChannel

	return err
}
[default: m_agent]`, agent.Dir) - - opts, err := docopt.Parse(usage, nil, true, "Cube Agent Checker 1.0.0", false) - if err != nil { - fmt.Println("ERROR:", err) - os.Exit(1) - } - - log_level, _ := strconv.Atoi(opts["-l"].(string)) - log_name := opts["--log_name"].(string) - log_dir := opts["--log_dir"].(string) - logex.SetLevel(getLogLevel(log_level)) - if err := logex.SetUpFileLogger(log_dir, log_name, nil); err != nil { - fmt.Println("ERROR:", err) - } - - logex.Notice("--- NEW SESSION -------------------------") - logex.Notice(">>> log_level:", log_level) - - agent.WorkerNum = 10 - if opts["-n"] != nil { - n, err := strconv.Atoi(opts["-n"].(string)) - if err == nil { - agent.WorkerNum = n - } - } - - agent.QueueCapacity = 20 - if opts["-q"] != nil { - q, err := strconv.Atoi(opts["-q"].(string)) - if err == nil { - agent.QueueCapacity = int32(q) - } - } - - agent.CmdWorkPool = agent.NewWorkPool(agent.WorkerNum, agent.QueueCapacity) - - if opts["-P"] == nil { - logex.Fatalf("ERROR: -P LISTEN PORT must be set!") - os.Exit(255) - } - - agentPort := opts["-P"].(string) - logex.Notice(">>> starting server...") - addr := ":" + agentPort - - if agent.StartHttp(addr) != nil { - logex.Noticef("cant start http(addr=%v). quit.", addr) - os.Exit(0) - } -} - -func getLogLevel(log_level int) logex.Level { - switch log_level { - case 16: - return logex.DEBUG - case 8: - return logex.TRACE - case 4: - return logex.NOTICE - case 2: - return logex.WARNING - case 1: - return logex.FATAL - case 0: - return logex.NONE - } - return logex.DEBUG -} diff --git a/core/general-server/op/general_infer_op.cpp b/core/general-server/op/general_infer_op.cpp index 742d27ef4612b8b201f8b21b5058dbf7525c9a9d..6f5dab9bcebd6f1f696e2af765f84c035a170217 100644 --- a/core/general-server/op/general_infer_op.cpp +++ b/core/general-server/op/general_infer_op.cpp @@ -12,15 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "examples/demo-serving/op/general_infer_op.h" #include #include #include #include +#include "core/general-server/op/general_infer_op.h" +#include "core/general-server/op/general_reader_op.h" #include "core/predictor/framework/infer.h" #include "core/predictor/framework/memory.h" #include "core/predictor/framework/resource.h" -#include "examples/demo-serving/op/general_reader_op.h" + namespace baidu { namespace paddle_serving { diff --git a/core/general-server/op/general_infer_op.h b/core/general-server/op/general_infer_op.h index ca839054e0f11b40fd5f461307f3121d338028f8..f6860f0941afb941623bf9b876e128c06f5a0911 100644 --- a/core/general-server/op/general_infer_op.h +++ b/core/general-server/op/general_infer_op.h @@ -23,7 +23,7 @@ #else #include "paddle_inference_api.h" // NOLINT #endif -#include "examples/demo-serving/general_model_service.pb.h" +#include "core/general-server/general_model_service.pb.h" namespace baidu { namespace paddle_serving { diff --git a/core/general-server/op/general_reader_op.cpp b/core/general-server/op/general_reader_op.cpp index b692ba9796dc47f9710f7db96372636f7b42140a..1db8620c566f270fc4697781f9080d1bd0967fce 100644 --- a/core/general-server/op/general_reader_op.cpp +++ b/core/general-server/op/general_reader_op.cpp @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "examples/demo-serving/op/general_reader_op.h" #include #include #include #include +#include "core/general-server/op/general_reader_op.h" #include "core/predictor/framework/infer.h" #include "core/predictor/framework/memory.h" diff --git a/core/general-server/op/general_reader_op.h b/core/general-server/op/general_reader_op.h index ce68dcaee53d68d707defeeeacd5dee2981120d0..4c68d70c37e79bf86838551d899f6cc25b2be923 100644 --- a/core/general-server/op/general_reader_op.h +++ b/core/general-server/op/general_reader_op.h @@ -25,8 +25,8 @@ #endif #include #include "core/predictor/framework/resource.h" -#include "examples/demo-serving/general_model_service.pb.h" -#include "examples/demo-serving/load_general_model_service.pb.h" +#include "core/general-server/general_model_service.pb.h" +#include "core/general-server/load_general_model_service.pb.h" namespace baidu { namespace paddle_serving { diff --git a/examples/demo-serving/op/general_infer_op.cpp b/examples/demo-serving/op/general_infer_op.cpp deleted file mode 100644 index 742d27ef4612b8b201f8b21b5058dbf7525c9a9d..0000000000000000000000000000000000000000 --- a/examples/demo-serving/op/general_infer_op.cpp +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
// ---- examples/demo-serving/op/general_infer_op.cpp (body) ----
#include "examples/demo-serving/op/general_infer_op.h"
// NOTE(review): the next four #include directives lost their header names in
// this dump — restore them from version control.
#include
#include
#include
#include
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
#include "examples/demo-serving/op/general_reader_op.h"

namespace baidu {
namespace paddle_serving {
namespace serving {

using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::FetchInst;
using baidu::paddle_serving::predictor::InferManager;

// Runs the fluid model on the tensors produced by general_reader_op and packs
// the outputs into the Response protobuf, row by row.
// NOTE(review): template/cast type arguments below were stripped by this dump
// and have been restored from context — verify against version control.
int GeneralInferOp::inference() {
  const GeneralReaderOutput *reader_out =
      get_depend_argument<GeneralReaderOutput>("general_reader_op");
  if (!reader_out) {
    LOG(ERROR) << "Failed mutable depended argument, op:"
               << "general_reader_op";
    return -1;
  }

  int reader_status = reader_out->reader_status;
  if (reader_status != 0) {
    LOG(ERROR) << "Read request wrong.";
    return -1;
  }

  const TensorVector *in = &reader_out->tensor_vector;
  // NOTE(review): `out` is never returned to the object pool (the cleanup at
  // the bottom is commented out), so this leaks one TensorVector per request.
  TensorVector *out = butil::get_object<TensorVector>();
  int batch_size = (*in)[0].shape[0];  // assumes a non-empty input — TODO confirm
  // infer
  if (InferManager::instance().infer(GENERAL_MODEL_NAME, in, out, batch_size)) {
    LOG(ERROR) << "Failed do infer in fluid model: " << GENERAL_MODEL_NAME;
    return -1;
  }

  Response *res = mutable_data<Response>();

  // Response skeleton: one FetchInst per batch row, one Tensor per output var.
  for (int i = 0; i < batch_size; ++i) {
    FetchInst *fetch_inst = res->add_insts();
    for (int j = 0; j < out->size(); ++j) {
      Tensor *tensor = fetch_inst->add_tensor_array();
      tensor->set_elem_type(1);  // presumably 1 == float — verify against proto
      if (out->at(j).lod.size() == 1) {
        tensor->add_shape(-1);  // variable-length (lod) output
      } else {
        // Per-row shape: drop the batch dimension (index 0).
        for (int k = 1; k < out->at(j).shape.size(); ++k) {
          tensor->add_shape(out->at(j).shape[k]);
        }
      }
    }
  }

  // Copy output data per batch row; lod outputs use their lod offsets, dense
  // outputs use a fixed per-row capacity.
  for (int i = 0; i < out->size(); ++i) {
    float *data_ptr = static_cast<float *>(out->at(i).data.data());
    int cap = 1;
    for (int j = 1; j < out->at(i).shape.size(); ++j) {
      cap *= out->at(i).shape[j];
    }
    if (out->at(i).lod.size() == 1) {
      for (int j = 0; j < batch_size; ++j) {
        for (int k = out->at(i).lod[0][j]; k < out->at(i).lod[0][j + 1]; k++) {
          res->mutable_insts(j)->mutable_tensor_array(i)->add_data(
              reinterpret_cast<char *>(&(data_ptr[k])), sizeof(float));
        }
      }
    } else {
      for (int j = 0; j < batch_size; ++j) {
        for (int k = j * cap; k < (j + 1) * cap; ++k) {
          res->mutable_insts(j)->mutable_tensor_array(i)->add_data(
              reinterpret_cast<char *>(&(data_ptr[k])), sizeof(float));
        }
      }
    }
  }
  /* (cleanup kept disabled in the original)
  for (size_t i = 0; i < in->size(); ++i) {
    (*in)[i].shape.clear();
  }
  in->clear();
  butil::return_object(in);

  for (size_t i = 0; i < out->size(); ++i) {
    (*out)[i].shape.clear();
  }
  out->clear();
  butil::return_object(out);
  }
  */
  return 0;
}
DEFINE_OP(GeneralInferOp);

}  // namespace serving
}  // namespace paddle_serving
}  // namespace baidu
- -#pragma once -#include -#ifdef BCLOUD -#ifdef WITH_GPU -#include "paddle/paddle_inference_api.h" -#else -#include "paddle/fluid/inference/api/paddle_inference_api.h" -#endif -#else -#include "paddle_inference_api.h" // NOLINT -#endif -#include "examples/demo-serving/general_model_service.pb.h" - -namespace baidu { -namespace paddle_serving { -namespace serving { - -static const char* GENERAL_MODEL_NAME = "general_model"; - -class GeneralInferOp - : public baidu::paddle_serving::predictor::OpWithChannel< - baidu::paddle_serving::predictor::general_model::Response> { - public: - typedef std::vector TensorVector; - - DECLARE_OP(GeneralInferOp); - - int inference(); -}; - -} // namespace serving -} // namespace paddle_serving -} // namespace baidu diff --git a/examples/demo-serving/op/general_reader_op.cpp b/examples/demo-serving/op/general_reader_op.cpp deleted file mode 100644 index 5b27b86b51fa3eec7c569f4832d875e09ef02600..0000000000000000000000000000000000000000 --- a/examples/demo-serving/op/general_reader_op.cpp +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
// ---- examples/demo-serving/op/general_reader_op.cpp (body) ----
#include "examples/demo-serving/op/general_reader_op.h"
// NOTE(review): the next four #include directives lost their header names in
// this dump — restore them from version control.
#include
#include
#include
#include
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"

namespace baidu {
namespace paddle_serving {
namespace serving {

using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FeedInst;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;

// Validates the request's feed vars against the server's model config
// (count, elem type, and shape). Returns 0 on match, -1 otherwise.
// NOTE(review): template arguments below were stripped by this dump and have
// been restored from context — verify against version control.
int conf_check(const Request *req,
               const std::shared_ptr<PaddleGeneralModelConfig> &model_config) {
  int var_num = req->insts(0).tensor_array_size();
  VLOG(2) << "var num: " << var_num;
  if (var_num != model_config->_feed_type.size()) {
    LOG(ERROR) << "feed var number not match.";
    return -1;
  }

  VLOG(2) << "begin to checkout feed type";
  for (int i = 0; i < var_num; ++i) {
    VLOG(2) << "feed type[" << i << "]: " << model_config->_feed_type[i];
    if (model_config->_feed_type[i] !=
        req->insts(0).tensor_array(i).elem_type()) {
      LOG(ERROR) << "feed type not match.";
      return -1;
    }
    VLOG(2) << "feed shape size: " << model_config->_feed_shape[i].size();
    if (model_config->_feed_shape[i].size() ==
        req->insts(0).tensor_array(i).shape_size()) {
      for (int j = 0; j < model_config->_feed_shape[i].size(); ++j) {
        // NOTE(review): the next statement is a no-op — likely leftover debris.
        req->insts(0).tensor_array(i).shape(j);
        if (model_config->_feed_shape[i][j] !=
            req->insts(0).tensor_array(i).shape(j)) {
          LOG(ERROR) << "feed shape not match.";
          return -1;
        }
      }
    } else {
      LOG(ERROR) << "feed shape not match.";
      return -1;
    }
  }
  return 0;
}

// Parses the client Request into PaddleTensors (int64 or float, lod or dense)
// and publishes them via GeneralReaderOutput for the infer op.
int GeneralReaderOp::inference() {
  // reade request from client
  const Request *req = dynamic_cast<const Request *>(get_request_message());

  int batch_size = req->insts_size();
  int input_var_num = 0;

  // NOTE(review): element types of these vectors were stripped by this dump;
  // `int` fits their usage (elem type codes, byte sizes, element counts) —
  // verify against version control.
  std::vector<int> elem_type;
  std::vector<int> elem_size;
  std::vector<int> capacity;

  GeneralReaderOutput *res = mutable_data<GeneralReaderOutput>();
  TensorVector *in = &res->tensor_vector;

  // NOTE(review): this null check comes after `res` was already dereferenced
  // on the previous line, so it can never protect anything — it should
  // precede `&res->tensor_vector`.
  if (!res) {
    LOG(ERROR) << "Failed get op tls reader object output";
  }
  if (batch_size <= 0) {
    res->reader_status = -1;
    return 0;
  }

  int var_num = req->insts(0).tensor_array_size();
  VLOG(2) << "var num: " << var_num;
  // read config

  LOG(INFO) << "start to call load general model_conf op";
  baidu::paddle_serving::predictor::Resource &resource =
      baidu::paddle_serving::predictor::Resource::instance();

  LOG(INFO) << "get resource pointer done.";
  std::shared_ptr<PaddleGeneralModelConfig> model_config =
      resource.get_general_model_config();

  LOG(INFO) << "print general model config done.";

  // check
  res->reader_status = conf_check(req, model_config);
  if (res->reader_status != 0) {
    LOG(INFO) << "model conf of server:";
    resource.print_general_model_config(model_config);
    return 0;
  }
  // package tensor

  elem_type.resize(var_num);
  elem_size.resize(var_num);
  capacity.resize(var_num);
  // NOTE(review): lod_tensor is declared once and reused for every var, so
  // its shape/lod vectors accumulate entries across iterations — confirm this
  // is intended before reusing this code.
  paddle::PaddleTensor lod_tensor;
  for (int i = 0; i < var_num; ++i) {
    elem_type[i] = req->insts(0).tensor_array(i).elem_type();
    VLOG(2) << "var[" << i << "] has elem type: " << elem_type[i];
    if (elem_type[i] == 0) {  // int64
      elem_size[i] = sizeof(int64_t);
      lod_tensor.dtype = paddle::PaddleDType::INT64;
    } else {
      elem_size[i] = sizeof(float);
      lod_tensor.dtype = paddle::PaddleDType::FLOAT32;
    }

    if (req->insts(0).tensor_array(i).shape(0) == -1) {
      // shape(0) == -1 marks a variable-length (lod) tensor.
      lod_tensor.lod.resize(1);
      lod_tensor.lod[0].push_back(0);
      VLOG(2) << "var[" << i << "] is lod_tensor";
    } else {
      lod_tensor.shape.push_back(batch_size);
      capacity[i] = 1;
      for (int k = 0; k < req->insts(0).tensor_array(i).shape_size(); ++k) {
        int dim = req->insts(0).tensor_array(i).shape(k);
        VLOG(2) << "shape for var[" << i << "]: " << dim;
        capacity[i] *= dim;
        lod_tensor.shape.push_back(dim);
      }
      VLOG(2) << "var[" << i << "] is tensor, capacity: " << capacity[i];
    }
    // NOTE(review): names are hard-coded for a two-var (words/label) model.
    if (i == 0) {
      lod_tensor.name = "words";
    } else {
      lod_tensor.name = "label";
    }
    in->push_back(lod_tensor);
  }

  // Size each tensor's buffer: lod tensors from accumulated offsets, dense
  // tensors from batch_size * capacity.
  for (int i = 0; i < var_num; ++i) {
    if (in->at(i).lod.size() == 1) {
      for (int j = 0; j < batch_size; ++j) {
        const Tensor &tensor = req->insts(j).tensor_array(i);
        int data_len = tensor.data_size();
        VLOG(2) << "tensor size for var[" << i << "]: " << tensor.data_size();
        int cur_len = in->at(i).lod[0].back();
        VLOG(2) << "current len: " << cur_len;
        in->at(i).lod[0].push_back(cur_len + data_len);
        VLOG(2) << "new len: " << cur_len + data_len;
      }
      in->at(i).data.Resize(in->at(i).lod[0].back() * elem_size[i]);
      in->at(i).shape = {in->at(i).lod[0].back(), 1};
      VLOG(2) << "var[" << i
              << "] is lod_tensor and len=" << in->at(i).lod[0].back();
    } else {
      in->at(i).data.Resize(batch_size * capacity[i] * elem_size[i]);
      VLOG(2) << "var[" << i
              << "] is tensor and capacity=" << batch_size * capacity[i];
    }
  }

  // Copy the raw bytes out of the request strings into the tensor buffers.
  for (int i = 0; i < var_num; ++i) {
    if (elem_type[i] == 0) {
      int64_t *dst_ptr = static_cast<int64_t *>(in->at(i).data.data());
      int offset = 0;
      for (int j = 0; j < batch_size; ++j) {
        for (int k = 0; k < req->insts(j).tensor_array(i).data_size(); ++k) {
          dst_ptr[offset + k] =
              *(const int64_t *)req->insts(j).tensor_array(i).data(k).c_str();
        }
        if (in->at(i).lod.size() == 1) {
          offset = in->at(i).lod[0][j + 1];
        } else {
          offset += capacity[i];
        }
      }
    } else {
      float *dst_ptr = static_cast<float *>(in->at(i).data.data());
      int offset = 0;
      for (int j = 0; j < batch_size; ++j) {
        for (int k = 0; k < req->insts(j).tensor_array(i).data_size(); ++k) {
          dst_ptr[offset + k] =
              *(const float *)req->insts(j).tensor_array(i).data(k).c_str();
        }
        if (in->at(i).lod.size() == 1) {
          offset = in->at(i).lod[0][j + 1];
        } else {
          offset += capacity[i];
        }
      }
    }
  }

  VLOG(2) << "read data from client success";
  // print request
  // NOTE(review): unconditionally dumps 10 int64s from the first var for
  // debugging — assumes var 0 is int64 and has >= 10 elements.
  std::ostringstream oss;
  int64_t *example = reinterpret_cast<int64_t *>((*in)[0].data.data());
  for (int i = 0; i < 10; i++) {
    oss << *(example + i) << " ";
  }
  VLOG(2) << "head element of first feed var : " << oss.str();
  //
  return 0;
}
DEFINE_OP(GeneralReaderOp);
}  // namespace serving
}  // namespace paddle_serving
}  // namespace baidu
- -#pragma once -#include -#ifdef BCLOUD -#ifdef WITH_GPU -#include "paddle/paddle_inference_api.h" -#else -#include "paddle/fluid/inference/api/paddle_inference_api.h" -#endif -#else -#include "paddle_inference_api.h" // NOLINT -#endif -#include -#include "core/predictor/framework/resource.h" -#include "examples/demo-serving/general_model_service.pb.h" -#include "examples/demo-serving/load_general_model_service.pb.h" - -namespace baidu { -namespace paddle_serving { -namespace serving { - -struct GeneralReaderOutput { - std::vector tensor_vector; - int reader_status = 0; - - void Clear() { - size_t tensor_count = tensor_vector.size(); - for (size_t ti = 0; ti < tensor_count; ++ti) { - tensor_vector[ti].shape.clear(); - } - tensor_vector.clear(); - } - std::string ShortDebugString() const { return "Not implemented!"; } -}; - -class GeneralReaderOp : public baidu::paddle_serving::predictor::OpWithChannel< - GeneralReaderOutput> { - public: - typedef std::vector TensorVector; - - DECLARE_OP(GeneralReaderOp); - - int inference(); -}; - -} // namespace serving -} // namespace paddle_serving -} // namespace baidu diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py index 03be0866984f0c03c72c9a3fe14b0774183c1aa2..8a1ce27858acb8303c741f094b2b99b4e0b5f3b5 100644 --- a/python/paddle_serving_client/__init__.py +++ b/python/paddle_serving_client/__init__.py @@ -141,6 +141,9 @@ class Client(object): result = self.client_handle_.predict( float_slot, float_feed_names, int_slot, int_feed_names, fetch_names) + # TODO(guru4elephant): the order of fetch var name should be consistent with + # general_model_config, this is not friendly + # In the future, we need make the number of fetched variable changable result_map = {} for i, name in enumerate(fetch_names): result_map[name] = result[i] diff --git a/python/paddle_serving_client/io/__init__.py b/python/paddle_serving_client/io/__init__.py index 
eb042154a087b6d0252560682acf161c2afe4549..f3c41d7625d0e444d3ab5b584bb1b2d775700612 100644 --- a/python/paddle_serving_client/io/__init__.py +++ b/python/paddle_serving_client/io/__init__.py @@ -18,7 +18,7 @@ from paddle.fluid.framework import core from paddle.fluid.framework import default_main_program from paddle.fluid.framework import Program from paddle.fluid import CPUPlace -from paddle.fluid.io import save_persistables +from paddle.fluid.io import save_inference_model from ..proto import general_model_config_pb2 as model_conf import os @@ -27,19 +27,13 @@ def save_model(server_model_folder, feed_var_dict, fetch_var_dict, main_program=None): - if main_program is None: - main_program = default_main_program() - elif isinstance(main_program, CompiledProgram): - main_program = main_program._program - if main_program is None: - raise TypeError("program should be as Program type or None") - if not isinstance(main_program, Program): - raise TypeError("program should be as Program type or None") - executor = Executor(place=CPUPlace()) - save_persistables(executor, server_model_folder, - main_program) + feed_var_names = [feed_var_dict[x].name for x in feed_var_dict] + target_vars = fetch_var_dict.values() + + save_inference_model(server_model_folder, feed_var_names, + target_vars, executor, main_program=main_program) config = model_conf.GeneralModelConfig() @@ -71,10 +65,11 @@ def save_model(server_model_folder, config.fetch_var.extend([fetch_var]) cmd = "mkdir -p {}".format(client_config_folder) + os.system(cmd) - with open("{}/serving_client_conf.prototxt", "w") as fout: + with open("{}/serving_client_conf.prototxt".format(client_config_folder), "w") as fout: fout.write(str(config)) - with open("{}/serving_server_conf.prototxt", "w") as fout: + with open("{}/serving_server_conf.prototxt".format(server_model_folder), "w") as fout: fout.write(str(config)) diff --git a/python/paddle_serving_server/__init__.py b/python/paddle_serving_server/__init__.py index 
e0bda091b9efbd8033f301a431dd905e779aec3b..c0c086f4ad0e31be4d2ab0c595d8d475b679f43a 100644 --- a/python/paddle_serving_server/__init__.py +++ b/python/paddle_serving_server/__init__.py @@ -175,7 +175,7 @@ class Server(object): def run_server(self): # just run server with system command # currently we do not load cube - command = "/home/users/dongdaxiang/github_develop/Serving/build_server/core/general-server" \ + command = "/home/users/dongdaxiang/github_develop/Serving/build_server/core/general-server/serving" \ " -enable_model_toolkit " \ "-inferservice_path {} " \ "-inferservice_file {} " \