Commit 2d46b7d7 authored by guru4elephant

move general_infer_op and general_reader_op

Parent cef6f52f
......@@ -134,7 +134,4 @@ add_subdirectory(paddle_inference)
endif()
add_subdirectory(python)
set(PYTHON_INCLUDE_DIR ${PYTHON_INCLUDE})
set(PYTHON_LIBRARIES ${PYTHON_LIB})
#add_subdirectory(examples)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set(SOURCE_FILE cube-agent.go)
add_go_executable(cube-agent ${SOURCE_FILE})
add_dependencies(cube-agent agent-docopt-go)
add_dependencies(cube-agent agent-logex)
add_dependencies(cube-agent agent-pipeline)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package agent
import (
"errors"
_ "github.com/Badangel/logex"
"strings"
"sync"
)
var (
Dir string
WorkerNum int
QueueCapacity int32
MasterHost []string
MasterPort []string
TestHostname string
TestIdc string
ShardLock sync.RWMutex
CmdWorkPool *WorkPool
CmdWorkFilter sync.Map
)
type (
Status struct {
Status string `json:"status"`
Version string `json:"version"`
}
MasterResp struct {
Success string `json:"success"`
Message string `json:"message"`
Data string `json:"data"`
}
ShardInfo struct {
DictName string
ShardSeq int
SlotIdList string
DataDir string
Service string `json:"service,omitempty"`
Libcube string `json:"libcube,omitempty"`
}
CubeResp struct {
Status int `json:"status"`
CurVersion string `json:"cur_version"`
BgVersion string `json:"bg_version"`
}
)
var BUILTIN_STATUS = Status{"RUNNING", "3.0.0.1"}
var ShardInfoMap map[string]map[string]*ShardInfo
var disks []string
func GetMaster(master string) (host, port string, err error) {
if len(ShardInfoMap) < 1 {
return "", "", errors.New("empty master list.")
}
if master == "" {
return MasterHost[0], MasterPort[0], nil
}
if _, ok := ShardInfoMap[master]; ok {
m := strings.Split(master, ":")
if len(m) != 2 {
return MasterHost[0], MasterPort[0], nil
}
return m[0], m[1], nil
} else {
return MasterHost[0], MasterPort[0], nil
}
}
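// Usage sketch (hypothetical address): GetMaster falls back to the first
// configured master when the requested one is unknown or malformed:
//
//	host, port, err := GetMaster("10.1.2.3:8001")
//	if err != nil {
//		logex.Warning("no master available:", err)
//	}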
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package agent
import (
"bytes"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/Badangel/logex"
)
type handlerFunc func(subpath string, m map[string]string, b []byte) (string, string, error)
var ( // key = subpath; eg: path="/checker/job", key="job"
getHandler map[string]handlerFunc
putHandler map[string]handlerFunc
deleteHandler map[string]handlerFunc
postHandler map[string]handlerFunc
)
func StartHttp(addr string) error {
// init handlers:
initGetHandlers()
initPostHandlers()
http.HandleFunc("/agent/", handleRest)
logex.Notice("start http ", addr)
return http.ListenAndServe(addr, nil)
}
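// A minimal start-up sketch (the port is illustrative): handlers are
// registered first, then the server blocks in ListenAndServe:
//
//	if err := StartHttp(":8000"); err != nil {
//		logex.Fatalf("http server exited: %v", err)
//	}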
func handleRest(w http.ResponseWriter, r *http.Request) {
var (
req_log string
status int32
)
time_begin := time.Now()
cont_type := make([]string, 1, 1)
cont_type[0] = "application/json"
header := w.Header()
header["Content-Type"] = cont_type
w.Header().Add("Access-Control-Allow-Origin", "*")
m := parseHttpKv(r)
b, _ := ioutil.ReadAll(r.Body)
req_log = fmt.Sprintf("handle %v %v %v from %v, len(m)=%v, m=%+v",
r.Method, r.URL.Path, r.URL.RawQuery, r.RemoteAddr, len(m), m)
api := r.URL.Path
var showHandler map[string]handlerFunc
switch r.Method {
case "GET":
showHandler = getHandler
case "POST": // create
showHandler = postHandler
case "PUT": // update
showHandler = putHandler
case "DELETE":
showHandler = deleteHandler
default:
logex.Warningf(`{"error":1, "message":"unsupport method %v"}`, r.Method)
}
handler, ok := showHandler[api]
if !ok {
key_list := make([]string, 0, len(showHandler))
for key := range showHandler {
key_list = append(key_list, key)
}
status = 2
fmt.Fprintf(w, `{"success":"%v", "message":"wrong api", "method":"%s", "api":"%s", "api_list":"%v"}`,
status, r.Method, api, key_list)
logex.Noticef(`%v, time=%v, status=%v`,
req_log, time.Now().Sub(time_begin).Nanoseconds()/1000000, status)
return
}
var s string
rst, handle_log, err := handler(api, m, b)
if err == nil {
status = 0
s = fmt.Sprintf(`{"success":"%v", "message":"query ok", "data":%s}`, status, rst)
} else {
status = 255
s = fmt.Sprintf(`{"success":"%v", "message":%v, "data":%s}`,
status, quote(err.Error()), rst)
}
if isJsonDict(s) {
fmt.Fprintln(w, s)
} else {
logex.Fatalf("invalid json: %v", s)
}
if err == nil {
logex.Noticef(`%v, time=%v, status=%v, handle_log=%v`,
req_log, time.Now().Sub(time_begin).Nanoseconds()/1000000,
status, quote(handle_log))
} else {
logex.Noticef(`%v, time=%v, status=%v, err=%v, handle_log=%v`,
req_log, time.Now().Sub(time_begin).Nanoseconds()/1000000,
status, quote(err.Error()), quote(handle_log))
}
}
func parseHttpKv(r *http.Request) map[string]string {
r.ParseForm()
m := make(map[string]string)
for k, v := range r.Form {
switch k {
case "user": // remove @baidu.com for user
m[k] = strings.Split(v[0], "@")[0]
default:
m[k] = v[0]
}
}
// allow passing hostname for debug
if _, ok := m["hostname"]; !ok {
ip := r.RemoteAddr[:strings.Index(r.RemoteAddr, ":")]
m["hostname"], _ = getHostname(ip)
}
return m
}
// restReq sends a restful request to requrl and returns response body.
func restReq(method, requrl string, timeout int, kv *map[string]string) (string, error) {
logex.Debug("####restReq####")
logex.Debug(*kv)
data := url.Values{}
if kv != nil {
for k, v := range *kv {
logex.Trace("req set:", k, v)
data.Set(k, v)
}
}
if method == "GET" || method == "DELETE" {
requrl = requrl + "?" + data.Encode()
data = url.Values{}
}
logex.Notice(method, requrl)
req, err := http.NewRequest(method, requrl, bytes.NewBufferString(data.Encode()))
if err != nil {
logex.Warning("NewRequest failed:", err)
return "", err
}
if method == "POST" || method == "PUT" {
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode())))
}
client := &http.Client{}
client.Timeout = time.Duration(timeout) * time.Second
resp, err := client.Do(req)
if err != nil {
logex.Warning("Do failed:", err)
return "", err
}
if resp.StatusCode < 200 || resp.StatusCode > 299 {
logex.Warning("resp status: " + resp.Status)
return "", errors.New("resp status: " + resp.Status)
}
body, err := ioutil.ReadAll(resp.Body)
return string(body), err
}
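// Example call (URL and fields are placeholders): a GET encodes the key/value
// parameters into the query string before the request is sent:
//
//	kv := map[string]string{"dict_name": "demo", "shard_seq": "0"}
//	body, err := restReq("GET", "http://master:8001/checker/job", 30, &kv)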
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package agent
import (
"encoding/json"
"fmt"
)
func initGetHandlers() {
getHandler = map[string]handlerFunc{
"/agent/status": GetStatus,
}
}
func GetStatus(subpath string, m map[string]string, b []byte) (string, string, error) {
b, err := json.Marshal(BUILTIN_STATUS)
if err != nil {
return quote(""), "", fmt.Errorf("json marshal failed, %v", err)
}
return string(b), "", err
}
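// Assuming the agent listens on :8000 (illustrative), the status endpoint can
// be probed directly; the response wraps BUILTIN_STATUS via handleRest:
//
//	curl http://127.0.0.1:8000/agent/status
//	// {"success":"0", "message":"query ok", "data":{"status":"RUNNING","version":"3.0.0.1"}}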
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package agent
import (
"encoding/json"
"fmt"
"github.com/Badangel/logex"
)
func initPostHandlers() {
postHandler = map[string]handlerFunc{
"/agent/cmd": PostCmd,
}
}
func PostCmd(subpath string, m map[string]string, b []byte) (string, string, error) {
var work Work
err := json.Unmarshal(b, &work)
if err != nil {
logex.Warningf("Unmarshal from %s error (+%v)", string(b), err)
return quote(""), "", fmt.Errorf("Work json unmarshal work failed, %v", err)
}
if _, ok := CmdWorkFilter.Load(work.Token()); ok {
logex.Warningf("Another work with same token is doing. Token(%s)", work.Token())
return quote(""), "", fmt.Errorf("Another work with same key is doing.", err)
}
CmdWorkFilter.Store(work.Token(), true)
err = work.DoWork()
CmdWorkFilter.Delete(work.Token())
if err != nil {
return quote(""), "", fmt.Errorf("Do work failed.", err)
}
return quote(""), "", err
}
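// A sketch of the expected POST body (field values are illustrative; the
// schema is the Work struct defined in work.go):
//
//	curl -d '{"dict_name":"demo_dict","shard_seq":0,"command":"download",
//	          "version":"1501836820","depend":"1501836820","port":"8027"}' \
//	     http://127.0.0.1:8000/agent/cmd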
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package agent
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/Badangel/logex"
)
// restReq sends a restful request to requrl and returns response body.
func RestReq(method, requrl string, timeout int, kv *map[string]string) (string, error) {
data := url.Values{}
if kv != nil {
for k, v := range *kv {
//logex.Trace("req set:", k, v)
data.Set(k, v)
}
}
if method == "GET" || method == "DELETE" {
requrl = requrl + "?" + data.Encode()
data = url.Values{}
}
//logex.Notice(method, requrl)
req, err := http.NewRequest(method, requrl, bytes.NewBufferString(data.Encode()))
if err != nil {
logex.Warning("NewRequest failed:", err)
return "", err
}
if method == "POST" || method == "PUT" {
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode())))
}
client := &http.Client{}
client.Timeout = time.Duration(timeout) * time.Second
resp, err := client.Do(req)
if err != nil {
logex.Warning("Do failed:", err)
return "", err
}
if resp.StatusCode < 200 || resp.StatusCode > 299 {
logex.Warning("resp status: " + resp.Status)
return "", errors.New("resp status: " + resp.Status)
}
body, err := ioutil.ReadAll(resp.Body)
return string(body), err
}
// restReq sends a restful request to requrl and returns response body as json.
func JsonReq(method, requrl string, timeout int, kv *map[string]string,
out interface{}) error {
s, err := RestReq(method, requrl, timeout, kv)
logex.Debugf("json request method:[%v], requrl:[%s], timeout:[%v], map[%v], out_str:[%s]", method, requrl, timeout, kv, s)
if err != nil {
return err
}
return json.Unmarshal([]byte(s), out)
}
func GetHdfsMeta(src string) (master, ugi, path string, err error) {
//src = "hdfs://root:rootpasst@st1-inf-platform0.st01.baidu.com:54310/user/mis_user/news_dnn_ctr_cube_1/1501836820/news_dnn_ctr_cube_1_part54.tar"
//src = "hdfs://st1-inf-platform0.st01.baidu.com:54310/user/mis_user/news_dnn_ctr_cube_1/1501836820/news_dnn_ctr_cube_1_part54.tar"
ugiBegin := strings.Index(src, "//")
ugiPos := strings.LastIndex(src, "@")
if ugiPos != -1 && ugiBegin != -1 {
ugi = src[ugiBegin+2 : ugiPos]
}
src1 := strings.Replace(strings.Replace(src, "hdfs://", "", 1), ugi, "", 1)
if ugi != "" {
src1 = src1[1:]
}
pos := strings.Index(src1, "/")
if pos != -1 {
master = src1[0:pos]
path = src1[pos:]
} else {
logex.Warningf("failed to get the master or path for (%s)", src)
err = errors.New("invalid master or path found")
}
logex.Debugf("parse the (%s) succ, master is %s, ugi is (%s), path is %s", src, master, ugi, path)
return
}
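// For the commented sample src above, the parse yields (illustrative):
//
//	master = "st1-inf-platform0.st01.baidu.com:54310"
//	ugi    = "root:rootpasst"
//	path   = "/user/mis_user/news_dnn_ctr_cube_1/1501836820/news_dnn_ctr_cube_1_part54.tar"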
func getHostIp() (string, error) {
if addrs, err := net.InterfaceAddrs(); err == nil {
for _, addr := range addrs {
ips := addr.String()
logex.Debugf("get host ip: %v", ips)
if strings.HasPrefix(ips, "127") {
continue
} else {
list := strings.Split(ips, "/")
if len(list) != 2 {
continue
}
return list[0], nil
}
}
}
return "unkown ip", errors.New("get host ip failed")
}
func getHostname(ip string) (hostname string, err error) {
if hostnames, err := net.LookupAddr(ip); err != nil {
hostname = ip
//logex.Warningf("cannot find the hostname of ip (%s), error (%v)", ip, err)
} else {
if len(hostnames) > 0 {
hostname = hostnames[0]
} else {
hostname = ip
}
}
return hostname, err
}
func GetLocalHostname() (hostname string, err error) {
if ip, err := getHostIp(); err == nil {
return getHostname(ip)
} else {
return "unkown ip", err
}
}
func GetLocalHostnameCmd() (hostname string, err error) {
cmd := "hostname"
stdout, _, err := RetryCmd(cmd, RETRY_TIMES)
if stdout != "" && err == nil {
hostname := strings.TrimSpace(stdout)
index := strings.LastIndex(hostname, ".baidu.com")
if index > 0 {
return hostname[:strings.LastIndex(hostname, ".baidu.com")], nil
} else {
return hostname, nil
}
} else {
logex.Debugf("using hostname cmd failed. err:%v", err)
return GetLocalHostname()
}
}
// quote quotes string for json output. eg: s="123", quote(s)=`"123"`
func quote(s string) string {
return fmt.Sprintf("%q", s)
}
// quoteb quotes byte array for json output.
func quoteb(b []byte) string {
return quote(string(b))
}
// quotea quotes string array for json output
func quotea(a []string) string {
b, _ := json.Marshal(a)
return string(b)
}
func isJsonDict(s string) bool {
var js map[string]interface{}
return json.Unmarshal([]byte(s), &js) == nil
}
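// Quick reference for the json helpers above (inputs are illustrative):
//
//	quote("abc")             // `"abc"`
//	quoteb([]byte("abc"))    // `"abc"`
//	quotea([]string{"a"})    // `["a"]`
//	isJsonDict(`{"k":"v"}`)  // true
//	isJsonDict(`[1,2]`)      // false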
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package agent
import (
"crypto/md5"
"encoding/json"
"errors"
"fmt"
"github.com/Badangel/logex"
"github.com/Badangel/pipeline"
"os"
"os/exec"
"path"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
)
const (
COMMAND_DOWNLOAD = "download"
COMMAND_RELOAD = "reload"
COMMAND_SWITCH = "enable"
COMMAND_ROLLBACK = "rollback"
COMMAND_CHECK = "check"
COMMAND_CLEAR = "clear"
COMMAND_POP = "pop"
RETRY_TIMES = 3
REQUEST_MASTER_TIMEOUT_SECOND = 60
MAX_DOWN_CO = 7
RELOAD_RETRY_TIMES = 3
RELOAD_RETRY_INTERVAL_SECOND = 10
DOWNLOAD_DONE_MARK_FILE = ".download_done"
)
type VersionInfo struct {
Version string
Depend string
Source string
}
type Work struct {
DictName string `json:"dict_name"`
ShardSeq int `json:"shard_seq"`
DeployPath string `json:"deploy_path"`
Command string `json:"command"`
Version string `json:"version"`
Depend string `json:"depend"`
Source string `json:"source"`
Mode string `json:"mode"`
DictMode string `json:"dict_mode"`
Port string `json:"port"`
bRollback bool `json:"b_rollback"`
RollbackInfo []VersionInfo `json:"rollback_info"`
Status string `json:"status"`
FinishStatus string `json:"finish_status"`
Service string `json:"service,omitempty"`
VersionSign string `json:"version_sign,omitempty"`
MasterAddress string `json:"master_address,omitempty"`
ActiveVersionList string `json:"active_version_list,omitempty"`
}
func (work *Work) Token() string {
return work.DictName + strconv.Itoa(work.ShardSeq) + work.Service
}
func (work *Work) Valid() bool {
if work.Command == "" ||
work.Version == "" ||
work.Depend == "" {
return false
}
return true
}
func (work *Work) DoWork() error {
var err error
if !work.Valid() {
err = errors.New("Work is invalid")
return err
}
switch work.Command {
case COMMAND_DOWNLOAD:
err = work.Download()
case COMMAND_RELOAD:
err = work.Reload()
case COMMAND_SWITCH:
err = work.Enable()
case COMMAND_CHECK:
err = work.Check()
case COMMAND_CLEAR:
err = work.Clear()
case COMMAND_POP:
err = work.Pop()
default:
logex.Debugf("Invalid command %s received", work.Command)
err = errors.New("Invalid command.")
}
return err
}
func GetDownloadDirs(dictName, service, version, depend, deployPath string, shardSeq,
split int) ([]string, error) {
dirs := make([]string, 0, split)
if deployPath == "" {
return dirs, errors.New("Invalid Deploy path")
}
parentDir := getParentDir(version, depend)
if split < 2 {
disk := path.Join(deployPath, "cube_data")
if service == "" {
dirs = append(dirs, path.Join(disk, strconv.Itoa(shardSeq), parentDir))
} else {
dirs = append(dirs, path.Join(disk, strconv.Itoa(shardSeq), parentDir+"-"+dictName))
}
} else {
for i := 0; i < split; i++ {
disk := path.Join(deployPath, "cube_data")
if service == "" {
dirs = append(dirs, path.Join(disk, strconv.Itoa(shardSeq), strconv.Itoa(i), parentDir))
} else {
dirs = append(dirs, path.Join(disk, strconv.Itoa(shardSeq),
parentDir+"-"+dictName))
}
}
}
return dirs, nil
}
func GetDataLinkDirs(dictName, service, version, depend, deployPath string, shardSeq,
split int) []string {
dirs := make([]string, 0, split)
parentDir := getParentDir(version, depend)
if split < 2 {
disk := path.Join(deployPath, "data")
if service == "" {
dirs = append(dirs, path.Join(disk, parentDir))
}
} else {
for i := 0; i < split; i++ {
disk := path.Join(deployPath, "data")
if service == "" {
dirs = append(dirs, path.Join(disk, strconv.Itoa(i), parentDir))
}
}
}
return dirs
}
func (work *Work) Download() (err error) {
err = DoDownload(work.DictName, work.Service, work.Version, work.Depend, work.Mode, work.Source,
work.DeployPath, work.ShardSeq)
if err != nil {
logex.Warningf("download error, failed to download %s, dir is %s, error is (+%v)", work.Source, work.DeployPath, err)
return
}
if work.Service == "" {
err = UnTar(work.DictName, work.Service, work.Version, work.Depend, work.Source,
work.DeployPath, work.ShardSeq)
if err == nil {
dataPath := path.Join(work.DeployPath, "data")
// remove all old links
if work.Mode == "base" || len(work.RollbackInfo) != 0 {
cmd := fmt.Sprintf("ls -l %s | grep -E 'data.|index.' | awk '{print $9}'", dataPath)
stdout, _, err := RetryCmd(cmd, RETRY_TIMES)
if err == nil && stdout != "" {
fileNameLi := strings.Split(strings.TrimSpace(stdout), "\n")
for _, file := range fileNameLi {
err = os.Remove(path.Join(dataPath, file))
logex.Debugf("os.Remove(%s) error (%+v) ", path.Join(dataPath, file), err)
}
}
}
// create symbolic link to the version rollbacked
err = CreateSymlink(work.DictName, work.Service, work.Version, work.Depend, dataPath,
work.DeployPath, work.ShardSeq, len(strings.Split(work.Source, ";")))
} else {
logex.Warningf("download error, failed to untar for %s, dir is %s, error is (+%v)", work.Source, work.DeployPath, err)
}
}
if err == nil {
// clear history data
work.clearData()
work.clearLink()
} else {
logex.Warningf("create symlink failed, error is (+%v)", err)
}
return
}
func (work *Work) clearData() (err error) {
split := len(strings.Split(work.Source, ";"))
downloadDirs, err := GetDownloadDirs(work.DictName, work.Service, work.Version, work.Depend,
work.DeployPath, work.ShardSeq, split)
if err != nil {
logex.Warningf("clearData failed, error is (+%v)", err)
return
}
for _, downloadDir := range downloadDirs {
parentDir, _ := filepath.Split(downloadDir)
cmd := fmt.Sprintf("ls -l %s | grep -v %s | awk '{print $9}'", parentDir, work.Depend)
stdout, _, err := RetryCmd(cmd, RETRY_TIMES)
if err != nil || stdout == "" || work.Service != "" {
cmd = fmt.Sprintf("find %s -type d -ctime +1 -print | xargs -i rm -rf {}", parentDir)
_, _, err = RetryCmd(cmd, RETRY_TIMES)
} else {
rmList := strings.Split(strings.TrimSpace(stdout), "\n")
for i := 0; i < len(rmList); i++ {
if rmList[i] == "" {
continue
}
cmd = fmt.Sprintf("rm -rf %s/%s*", parentDir, rmList[i])
_, _, err = RetryCmd(cmd, RETRY_TIMES)
}
}
}
return
}
func (work *Work) clearPatchData() (err error) {
if work.Service != "" {
return
}
split := len(strings.Split(work.Source, ";"))
downloadDirs, err := GetDownloadDirs(work.DictName, work.Service, work.Version, work.Depend,
work.DeployPath, work.ShardSeq, split)
if err != nil {
logex.Warningf("clearPatchData failed, error is (+%v)", err)
return
}
for _, downloadDir := range downloadDirs {
parentDir, _ := filepath.Split(downloadDir)
cmd := fmt.Sprintf("ls -l %s | grep %s_ | awk '{print $9}'", parentDir, work.Depend)
stdout, _, err := RetryCmd(cmd, RETRY_TIMES)
if err == nil && stdout != "" {
rmList := strings.Split(strings.TrimSpace(stdout), "\n")
for i := 0; i < len(rmList); i++ {
if rmList[i] == "" {
continue
}
cmd = fmt.Sprintf("rm -rf %s/%s*", parentDir, rmList[i])
_, _, err = RetryCmd(cmd, RETRY_TIMES)
}
}
}
return
}
func (work *Work) clearLink() (err error) {
if work.Service != "" {
return
}
split := len(strings.Split(work.Source, ";"))
dataLinkDirs := GetDataLinkDirs(work.DictName, work.Service, work.Version, work.Depend,
work.DeployPath, work.ShardSeq, split)
for _, linkDir := range dataLinkDirs {
parentDir, _ := filepath.Split(linkDir)
cmd := fmt.Sprintf("ls -l %s | grep -v %s | awk '{print $9}'", parentDir, work.Depend)
stdout, _, err := RetryCmd(cmd, RETRY_TIMES)
if err != nil || stdout == "" {
cmd = fmt.Sprintf("find %s -type d -ctime +1 -print | xargs -i rm -rf {}", parentDir)
_, _, err = RetryCmd(cmd, RETRY_TIMES)
} else {
rmList := strings.Split(strings.TrimSpace(stdout), "\n")
for i := 0; i < len(rmList); i++ {
if rmList[i] == "" {
continue
}
cmd = fmt.Sprintf("rm -rf %s/%s*", parentDir, rmList[i])
_, _, err = RetryCmd(cmd, RETRY_TIMES)
}
}
}
return
}
func (work *Work) clearPatchLink() (err error) {
if work.Service != "" {
return
}
split := len(strings.Split(work.Source, ";"))
dataLinkDirs := GetDataLinkDirs(work.DictName, work.Service, work.Version, work.Depend,
work.DeployPath, work.ShardSeq, split)
for _, linkDir := range dataLinkDirs {
parentDir, _ := filepath.Split(linkDir)
cmd := fmt.Sprintf("ls -l %s | grep %s_ | awk '{print $9}'", parentDir, work.Depend)
stdout, _, err := RetryCmd(cmd, RETRY_TIMES)
if err == nil && stdout != "" {
rmList := strings.Split(strings.TrimSpace(stdout), "\n")
for i := 0; i < len(rmList); i++ {
if rmList[i] == "" {
continue
}
cmd = fmt.Sprintf("rm -rf %s/%s*", parentDir, rmList[i])
_, _, err = RetryCmd(cmd, RETRY_TIMES)
}
}
}
return
}
func UnTar(dictName, service, version, depend, source, deployPath string, shardSeq int) (err error) {
sources := strings.Split(source, ";")
downloadDirs, err := GetDownloadDirs(dictName, service, version, depend, deployPath, shardSeq,
len(sources))
if err != nil {
logex.Warningf("UnTar failed, error is (+%v)", err)
return
}
for i := 0; i < len(sources); i++ {
fileName := GetFileName(sources[i])
untarCmd := fmt.Sprintf("tar xvf %s -C %s", path.Join(downloadDirs[i], fileName), downloadDirs[i])
_, _, err = RetryCmd(untarCmd, RETRY_TIMES)
}
return
}
func CreateSymlink(dictName, service, version, depend, dataPath, deployPath string, shardSeq,
split int) (err error) {
downloadDirs, err := GetDownloadDirs(dictName, service, version, depend, deployPath, shardSeq, split)
if err != nil {
logex.Warningf("CreateSymlink failed, error is (+%v)", err)
}
for i, downloadDir := range downloadDirs {
cmd := fmt.Sprintf("ls -l %s | grep -E 'data.|index.' | awk '{print $NF}'", downloadDir)
stdout, _, err := RetryCmd(cmd, RETRY_TIMES)
if err == nil && stdout != "" {
fileNameLi := strings.Split(strings.TrimSpace(stdout), "\n")
versionDir := getParentDir(version, depend)
versionFile := path.Join(dataPath, "VERSION")
dataSubPath := ""
if split > 1 {
dataSubPath = path.Join(dataPath, strconv.Itoa(i), versionDir)
} else {
dataSubPath = path.Join(dataPath, versionDir)
}
if err = os.MkdirAll(dataSubPath, 0755); err != nil {
// return err
logex.Warningf("os.Mkdir %s failed, err:[%v]", dataSubPath, err)
}
if dataSubPath != "" {
cmd = fmt.Sprintf("find %s/.. -type d -ctime +5 -print | xargs -i rm -rf {}", dataSubPath)
_, _, err = RetryCmd(cmd, RETRY_TIMES)
}
for _, file := range fileNameLi {
dataLink := ""
tempDataPath := ""
if split > 1 {
dataLink = path.Join(dataPath, strconv.Itoa(i), file)
tempDataPath = path.Join(dataPath, strconv.Itoa(i))
} else {
dataLink = path.Join(dataPath, file)
tempDataPath = dataPath
}
cmd = fmt.Sprintf("rm -rf %s", dataLink)
_, stderr, _ := RetryCmd(cmd, RETRY_TIMES)
logex.Noticef("rm -rf %s, err:[%s]", dataLink, stderr)
// create new symlink
err = os.Symlink(path.Join(downloadDir, file), dataLink)
logex.Noticef("os.Symlink %s %s return (%+v)", path.Join(downloadDir, file), dataLink, err)
fmt.Println("os.Symlink: ", path.Join(downloadDir, file), dataLink, err)
cmd = fmt.Sprintf("cp -d %s/index.* %s/", tempDataPath, dataSubPath)
_, stderr, _ = RetryCmd(cmd, RETRY_TIMES)
logex.Noticef("cp -d index Symlink to version dir %s, err:[%s]", dataSubPath, stderr)
cmd = fmt.Sprintf("cp -d %s/data.* %s/", tempDataPath, dataSubPath)
_, stderr, _ = RetryCmd(cmd, RETRY_TIMES)
logex.Noticef("cp -d data Symlink to version dir %s, err:[%s]", dataSubPath, stderr)
}
cmd = fmt.Sprintf("echo %s > %s", versionDir, versionFile)
if _, _, err = RetryCmd(cmd, RETRY_TIMES); err != nil {
return err
}
}
}
return
}
func (work *Work) CheckToReload() bool {
statusCmd := fmt.Sprintf("curl -s -d '{\"cmd\":\"status\"}' http://127.0.0.1:%s/ControlService/cmd", work.Port)
stdout, _, _ := RetryCmd(statusCmd, RETRY_TIMES)
var resp CubeResp
json.Unmarshal([]byte(stdout), &resp)
version := getParentDir(work.Version, work.Depend)
if resp.CurVersion == "" && resp.BgVersion == "" {
logex.Noticef("cube version empty")
return true
}
if resp.CurVersion == version || resp.BgVersion == version {
logex.Noticef("cube version has matched. version: %s", version)
return false
}
return true
}
func (work *Work) Reload() (err error) {
if work.Port == "" {
err = errors.New("Reload with invalid port.")
return
}
if !work.CheckToReload() {
work.writeStatus("finish_reload", "succ")
return
}
work.writeStatus("prepare_reload", "")
var stdout string
versionPath := getParentDir(work.Version, work.Depend)
bgLoadCmd := "bg_load_base"
if work.Mode == "delta" {
bgLoadCmd = "bg_load_patch"
}
if work.ActiveVersionList == "" {
work.ActiveVersionList = "[]"
}
for i := 0; i < RELOAD_RETRY_TIMES; i++ {
reloadCmd := fmt.Sprintf("curl -o /dev/null -s -w %%{http_code} -d '{\"cmd\":\"%s\",\"version_path\":\"/%s\"}' http://127.0.0.1:%s/ControlService/cmd", bgLoadCmd, versionPath, work.Port)
fmt.Println("reload: ", reloadCmd)
stdout, _, _ = RetryCmd(reloadCmd, 1)
fmt.Println("reload stdout: ", stdout)
if strings.TrimSpace(stdout) == "200" {
logex.Debugf("bg_load_base return succ")
break
} else {
logex.Warning("bg_load_base return failed")
time.Sleep(RELOAD_RETRY_INTERVAL_SECOND * time.Second)
}
}
if strings.TrimSpace(stdout) == "200" {
work.writeStatus("finish_reload", "succ")
} else {
work.writeStatus("finish_reload", "failed")
err = errors.New("reload failed.")
}
return
}
func (work *Work) Clear() (err error) {
work.Service = ""
var stdout string
var clearCmd string
for i := 0; i < RETRY_TIMES; i++ {
clearCmd = fmt.Sprintf("curl -o /dev/null -s -w %%{http_code} -d '{\"cmd\":\"clear\",\"table_name\":\"%s\"}' http://127.0.0.1:%s/NodeControlService/cmd", work.DictName, work.Port)
fmt.Println("clear: ", clearCmd)
stdout, _, _ = RetryCmd(clearCmd, 1)
fmt.Println("clear stdout: ", stdout)
if strings.TrimSpace(stdout) == "200" {
logex.Debugf("clear return succ")
break
} else {
logex.Warning("clear return failed")
time.Sleep(RELOAD_RETRY_INTERVAL_SECOND * time.Second)
}
}
if strings.TrimSpace(stdout) == "200" {
err = work.writeStatus("succ", "")
} else {
err = work.writeStatus("failed", "")
}
return
}
func (work *Work) Check() (err error) {
if work.Service != "" || work.VersionSign == "" {
return
}
var dataLinkDirs []string
split := len(strings.Split(work.Source, ";"))
dataLinkDirs = GetDataLinkDirs(work.DictName, work.Service, work.Version, work.Depend,
work.DeployPath, work.ShardSeq, split)
if _, t_err := os.Stat(work.DeployPath); os.IsNotExist(t_err) {
logex.Noticef("check DeployPath[%s] not exists.", work.DeployPath)
return
}
check_succ := true
for _, linkDir := range dataLinkDirs {
parentDir, _ := filepath.Split(linkDir)
cmd := fmt.Sprintf("ls -l %s | grep %s | awk '{print $9}' | grep -v data | grep -v index", parentDir, work.Depend)
stdout, _, err := RetryCmd(cmd, RETRY_TIMES)
if err != nil || stdout == "" {
check_succ = false
break
} else {
versionList := strings.Split(strings.TrimSpace(stdout), "\n")
logex.Noticef("calc ver_sign for [%v]", versionList)
var version_sign string
var version string
for i := 0; i < len(versionList); i++ {
split_index := strings.Index(versionList[i], "_")
if split_index > 0 && split_index < len(versionList[i]) {
version = versionList[i][split_index+1:]
} else {
version = versionList[i]
}
if version_sign == "" {
version_sign = fmt.Sprintf("%x", md5.Sum([]byte(version)))
} else {
version_sign = fmt.Sprintf("%x", md5.Sum([]byte(version_sign)))
}
}
if version_sign != work.VersionSign {
logex.Warningf("version_sign check failed. real[%v] expect[%v]", version_sign, work.VersionSign)
check_succ = false
break
}
}
}
if !check_succ {
work.clearPatchData()
work.clearPatchLink()
master_host, master_port, _ := GetMaster(work.MasterAddress)
cmd := fmt.Sprintf("cd %s && export STRATEGY_DIR=%s && ./downloader -h %s -p %s -d %s -s %d",
work.DeployPath, work.DeployPath, master_host, master_port, work.DictName, work.ShardSeq)
_, _, err = RetryCmd(cmd, RETRY_TIMES)
}
return
}
func (work *Work) Enable() (err error) {
if work.Port == "" {
err = errors.New("Enable with invalid port")
return
}
var stdout string
var cmd string
versionPath := getParentDir(work.Version, work.Depend)
for i := 0; i < RELOAD_RETRY_TIMES; i++ {
if work.Service != "" {
cmd = fmt.Sprintf("curl -o /dev/null -s -w %%{http_code} -d '{\"cmd\":\"reload_model\",\"version\":\"%s-%s\",\"dict_name\":\"%s\"}' http://127.0.0.1:%s/ControlService/cmd",
versionPath, work.DictName, work.DictName, work.Port)
} else {
cmd = fmt.Sprintf("curl -o /dev/null -s -w %%{http_code} -d '{\"cmd\":\"enable\",\"version\":\"%s\"}' http://127.0.0.1:%s/ControlService/cmd", versionPath, work.Port)
}
stdout, _, _ = RetryCmd(cmd, 1)
if strings.TrimSpace(stdout) == "200" {
logex.Debugf("enable return succ for %s, work dir is %s", work.Source, work.DeployPath)
break
} else {
logex.Warningf("enable return failed for %s, work dir is %s, error is (%+v)", work.Source, work.DeployPath, err)
time.Sleep(RELOAD_RETRY_INTERVAL_SECOND * time.Second)
}
}
if strings.TrimSpace(stdout) == "200" {
err = work.writeStatus("succ", "")
} else {
err = work.writeStatus("failed", "")
}
if work.Service == "" {
cmd = fmt.Sprintf("curl -o /dev/null -s -w %%{http_code} -d '{\"cmd\":\"bg_unload\"}' http://127.0.0.1:%s/ControlService/cmd", work.Port)
stdout, _, _ = RetryCmd(cmd, RETRY_TIMES)
if strings.TrimSpace(stdout) == "200" {
logex.Debugf("unload return succ")
} else {
logex.Warning("unload return failed")
}
}
RemoveStateFile(work.DictName, work.ShardSeq, work.Service)
return
}
func (work *Work) Pop() (err error) {
var stdout string
var cmd string
if work.ActiveVersionList == "" {
work.ActiveVersionList = "[]"
}
for i := 0; i < RELOAD_RETRY_TIMES; i++ {
cmd = fmt.Sprintf("curl -o /dev/null -s -w %%{http_code} -d '{\"cmd\":\"pop\",\"table_name\":\"%s\",\"active_versions\":%v}' http://127.0.0.1:%s/NodeControlService/cmd", work.DictName, work.ActiveVersionList, work.Port)
fmt.Println("pop: ", cmd)
stdout, _, _ = RetryCmd(cmd, 1)
fmt.Println("pop stdout: ", stdout)
if strings.TrimSpace(stdout) == "200" {
logex.Debugf("pop return succ")
break
} else {
logex.Warning("pop return failed")
time.Sleep(RELOAD_RETRY_INTERVAL_SECOND * time.Second)
}
}
if strings.TrimSpace(stdout) == "200" {
err = work.writeStatus("succ", "")
} else {
err = work.writeStatus("failed", "")
}
RemoveStateFile(work.DictName, work.ShardSeq, work.Service)
return
}
func writeStateFile(dictName string, shardSeq int, service, state string) {
stateFile := fmt.Sprintf(".state_%s_%d", dictName, shardSeq)
if service != "" {
stateFile = stateFile + "_" + service
}
cmd := fmt.Sprintf("echo '%s' > %s/state/%s", state, Dir, stateFile)
if _, _, err := RetryCmd(cmd, RETRY_TIMES); err != nil {
logex.Warningf("%s error (%+v)", cmd, err)
}
}
func RemoveStateFile(dictName string, shardSeq int, service string) {
stateFile := fmt.Sprintf(".state_%s_%d", dictName, shardSeq)
if service != "" {
stateFile = stateFile + "_" + service
}
cmd := fmt.Sprintf("rm -f %s/state/%s", Dir, stateFile)
if _, _, err := RetryCmd(cmd, RETRY_TIMES); err != nil {
logex.Warningf("%s error (%+v)", cmd, err)
}
}
func (work *Work) writeStatus(status string, finishStatus string) (err error) {
work.Status = status
work.FinishStatus = finishStatus
state, _ := json.Marshal(work)
writeStateFile(work.DictName, work.ShardSeq, work.Service, string(state))
return
}
func DoDownloadIndividual(source, downloadDir string, isService bool, timeOut int, ch chan error, wg *sync.WaitGroup) {
err := errors.New("DoDownloadIndividual start")
for i := 0; i < RETRY_TIMES; i++ {
err = FtpDownload(source, downloadDir, timeOut)
if err == nil {
logex.Debugf("download %s to %s succ", source, downloadDir)
if !isService {
err = FtpDownload(source+".md5", downloadDir, timeOut)
}
} else {
logex.Warningf("download error , source %s, downloadDir %s, err (%+v)", source, downloadDir, err)
continue
}
if err == nil && isService {
// touch download_succ file
cmd := fmt.Sprintf("touch %s", path.Join(downloadDir, DOWNLOAD_DONE_MARK_FILE))
RetryCmd(cmd, RETRY_TIMES)
break
}
// download md5 file succ, md5check
if err == nil {
// md5sum -c
fileName := GetFileName(source)
err = checkMd5(path.Join(downloadDir, fileName), path.Join(downloadDir, fileName+".md5"))
logex.Warningf("md5sum check %s %s return (%+v)", path.Join(downloadDir, fileName), path.Join(downloadDir, fileName+".md5"), err)
if err == nil {
// touch download_succ file
cmd := fmt.Sprintf("touch %s", path.Join(downloadDir, DOWNLOAD_DONE_MARK_FILE))
RetryCmd(cmd, RETRY_TIMES)
logex.Debugf("md5sum ok, source is %s, dir is %s", source, downloadDir)
break
} else {
logex.Warningf("md5sum error, source is %s, dir is %s", source, downloadDir)
continue
}
} else {
logex.Warningf("download %s return (%+v)", source+".md5", err)
continue
}
}
ch <- err
wg.Done()
}
func checkSources(source string) ([]string, error) {
sources := strings.Split(source, ";")
for i := 0; i < len(sources); i++ {
if sources[i] == "" || (!strings.HasPrefix(sources[i], "ftp://") && !strings.HasPrefix(sources[i], "http://")) {
return sources, errors.New("Invalid sources")
}
}
return sources, nil
}
func DoDownload(dictName, service, version, depend, mode, source, deployPath string,
shardSeq int) (err error) {
sources, err := checkSources(source)
if err != nil {
logex.Warningf("checkSources %s return (%+v)", source, err)
return
}
downloadDirs, err := GetDownloadDirs(dictName, service, version, depend, deployPath, shardSeq,
len(sources))
if err != nil {
logex.Warningf("GetDownloadDirs %s return (%+v)", source, err)
return
}
version_suffix := ""
if service != "" {
version_suffix = version_suffix + "-" + dictName
}
if !checkToDownload(downloadDirs) {
cmd := fmt.Sprintf("cd %s/cube_data && echo %s > VERSION && cp VERSION VERSION-%s",
deployPath, getParentDir(version, depend)+version_suffix, dictName)
_, _, err = RetryCmd(cmd, RETRY_TIMES)
logex.Debugf("echo VERSION cmd:[%s] err:[%v]", cmd, err)
return
}
ch := make(chan error, len(sources))
wg := sync.WaitGroup{}
j := 0
numCo := 0
for ; j < len(sources); j++ {
if numCo >= MAX_DOWN_CO {
wg.Wait()
logex.Noticef("DoDownload co down.")
numCo = 0
}
numCo += 1
wg.Add(1)
time.Sleep(2000 * time.Millisecond)
timeOut := 900
if mode == "base" {
timeOut = 3600
}
go DoDownloadIndividual(sources[j], downloadDirs[j], (service != ""), timeOut, ch, &wg)
}
wg.Wait()
close(ch)
for err = range ch {
if err != nil {
return
}
}
cmd := fmt.Sprintf("cd %s/cube_data && echo %s > VERSION && cp VERSION VERSION-%s",
deployPath, getParentDir(version, depend)+version_suffix, dictName)
_, _, err = RetryCmd(cmd, RETRY_TIMES)
logex.Debugf("echo VERSION cmd:[%s] err:[%v]", cmd, err)
return
}
func FtpDownload(source string, dest string, timeOut int) (err error) {
dlCmd := fmt.Sprintf("wget --quiet --level=100 -P %s %s --limit-rate=10240k", dest, source)
fmt.Println(dlCmd)
_, _, err = RetryCmdWithSleep(dlCmd, RETRY_TIMES)
return
}
func checkToDownload(downloadDirs []string) bool {
for _, v := range downloadDirs {
if _, err := os.Stat(path.Join(v, DOWNLOAD_DONE_MARK_FILE)); err != nil {
logex.Noticef("check [%v] not exists.", v)
return true
}
}
return false
}
// simple hash
func getDownloadDisk(dictName string, shardSeq int) string {
index := len(dictName) * shardSeq % len(disks)
return disks[index]
}
func getParentDir(version string, depend string) (dir string) {
if version == depend {
dir = depend
} else {
dir = depend + "_" + version
}
return
}
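// Examples (values are illustrative): a base load uses the depend version
// alone, while a patch appends its own version:
//
//	getParentDir("1501836820", "1501836820") // "1501836820"
//	getParentDir("1501836999", "1501836820") // "1501836820_1501836999"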
func GetFileName(source string) (fileName string) {
s := strings.Split(source, "/")
fileName = s[len(s)-1]
return
}
func checkMd5(file string, fileMd5 string) (err error) {
cmd := fmt.Sprintf("md5sum %s | awk '{print $1}'", file)
stdout, _, _ := pipeline.Run(exec.Command("/bin/sh", "-c", cmd))
real_md5 := stdout.String()
cmd = fmt.Sprintf("cat %s | awk '{print $1}'", fileMd5)
stdout, _, _ = pipeline.Run(exec.Command("/bin/sh", "-c", cmd))
given_md5 := stdout.String()
if real_md5 != given_md5 {
logex.Warningf("checkMd5 failed real_md5[%s] given_md5[%s]", real_md5, given_md5)
err = errors.New("checkMd5 failed")
}
return
}
func RetryCmd(cmd string, retryTimes int) (stdoutStr string, stderrStr string, err error) {
for i := 0; i < retryTimes; i++ {
stdout, stderr, e := pipeline.Run(exec.Command("/bin/sh", "-c", cmd))
stdoutStr = stdout.String()
stderrStr = stderr.String()
err = e
logex.Debugf("cmd %s, stdout: %s, stderr: %s, err: (%+v)", cmd, stdoutStr, stderrStr, err)
if err == nil {
break
}
}
return
}
func RetryCmdWithSleep(cmd string, retryTimes int) (stdoutStr string, stderrStr string, err error) {
for i := 0; i < retryTimes; i++ {
stdout, stderr, e := pipeline.Run(exec.Command("/bin/sh", "-c", cmd))
stdoutStr = stdout.String()
stderrStr = stderr.String()
err = e
logex.Debugf("cmd %s, stdout: %s, stderr: %s, err: (%+v)", cmd, stdoutStr, stderrStr, err)
if err == nil {
break
}
time.Sleep(10000 * time.Millisecond)
}
return
}
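// Usage sketch (the command is illustrative): run a shell command through the
// pipeline helper, retrying up to RETRY_TIMES and stopping on first success:
//
//	stdout, stderr, err := RetryCmd("ls /home/work/cube/cube_data", RETRY_TIMES)
//	if err != nil {
//		logex.Warningf("cmd failed, stderr: %s, err: %v", stderr, err)
//	}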
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package agent
import (
"errors"
"fmt"
"sync"
"sync/atomic"
)
type (
workType struct {
poolWorker PoolWorker
resultChannel chan error
}
WorkPool struct {
queueChannel chan workType
workChannel chan PoolWorker
queuedWorkNum int32
activeWorkerNum int32
queueCapacity int32
workFilter sync.Map
}
)
type PoolWorker interface {
Token() string
DoWork()
}
func NewWorkPool(workerNum int, queueCapacity int32) *WorkPool {
workPool := WorkPool{
queueChannel: make(chan workType),
workChannel: make(chan PoolWorker, queueCapacity),
queuedWorkNum: 0,
activeWorkerNum: 0,
queueCapacity: queueCapacity,
}
for i := 0; i < workerNum; i++ {
go workPool.startWorkRoutine()
}
go workPool.startQueueRoutine()
return &workPool
}
func (workPool *WorkPool) startWorkRoutine() {
for {
select {
case work := <-workPool.workChannel:
workPool.doWork(work)
break
}
}
}
func (workPool *WorkPool) startQueueRoutine() {
for {
select {
case queueItem := <-workPool.queueChannel:
if atomic.AddInt32(&workPool.queuedWorkNum, 0) == workPool.queueCapacity {
queueItem.resultChannel <- fmt.Errorf("work pool fulled with %v pending works", QueueCapacity)
continue
}
atomic.AddInt32(&workPool.queuedWorkNum, 1)
workPool.workChannel <- queueItem.poolWorker
queueItem.resultChannel <- nil
break
}
}
}
func (workPool *WorkPool) doWork(poolWorker PoolWorker) {
defer atomic.AddInt32(&workPool.activeWorkerNum, -1)
defer workPool.workFilter.Delete(poolWorker.Token())
atomic.AddInt32(&workPool.queuedWorkNum, -1)
atomic.AddInt32(&workPool.activeWorkerNum, 1)
poolWorker.DoWork()
}
func (workPool *WorkPool) PostWorkWithToken(poolWorker PoolWorker) (err error) {
if _, ok := workPool.workFilter.Load(poolWorker.Token()); ok {
return errors.New("another work with same key is doing.")
}
workPool.workFilter.Store(poolWorker.Token(), true)
return workPool.PostWork(poolWorker)
}
func (workPool *WorkPool) PostWork(poolWorker PoolWorker) (err error) {
work := workType{poolWorker, make(chan error)}
defer close(work.resultChannel)
workPool.queueChannel <- work
err = <-work.resultChannel
return err
}
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"agent"
"fmt"
"github.com/Badangel/logex"
"github.com/docopt/docopt-go"
"os"
"path/filepath"
"runtime"
"strconv"
)
func main() {
runtime.GOMAXPROCS(runtime.NumCPU())
agent.Dir, _ = filepath.Abs(filepath.Dir(os.Args[0]))
usage := `Usage: ./cube-agent [options]
Options:
-n WORKERNUM set worker num.
-q QUEUENUM set queue num.
-P LISTEN_PORT agent listen port
Log options:
-l LOG_LEVEL set log level, values: 0,1,2,4,8,16. [default: 16]
--log_dir=DIR set log output dir. [default: ./log]
--log_name=NAME set log name. [default: m_agent]`
opts, err := docopt.Parse(usage, nil, true, "Cube Agent Checker 1.0.0", false)
if err != nil {
fmt.Println("ERROR:", err)
os.Exit(1)
}
log_level, _ := strconv.Atoi(opts["-l"].(string))
log_name := opts["--log_name"].(string)
log_dir := opts["--log_dir"].(string)
logex.SetLevel(getLogLevel(log_level))
if err := logex.SetUpFileLogger(log_dir, log_name, nil); err != nil {
fmt.Println("ERROR:", err)
}
logex.Notice("--- NEW SESSION -------------------------")
logex.Notice(">>> log_level:", log_level)
agent.WorkerNum = 10
if opts["-n"] != nil {
n, err := strconv.Atoi(opts["-n"].(string))
if err == nil {
agent.WorkerNum = n
}
}
agent.QueueCapacity = 20
if opts["-q"] != nil {
q, err := strconv.Atoi(opts["-q"].(string))
if err == nil {
agent.QueueCapacity = int32(q)
}
}
agent.CmdWorkPool = agent.NewWorkPool(agent.WorkerNum, agent.QueueCapacity)
if opts["-P"] == nil {
logex.Fatalf("ERROR: -P LISTEN PORT must be set!")
os.Exit(255)
}
agentPort := opts["-P"].(string)
logex.Notice(">>> starting server...")
addr := ":" + agentPort
if agent.StartHttp(addr) != nil {
logex.Noticef("cant start http(addr=%v). quit.", addr)
os.Exit(0)
}
}
func getLogLevel(log_level int) logex.Level {
switch log_level {
case 16:
return logex.DEBUG
case 8:
return logex.TRACE
case 4:
return logex.NOTICE
case 2:
return logex.WARNING
case 1:
return logex.FATAL
case 0:
return logex.NONE
}
return logex.DEBUG
}
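// An illustrative invocation (port and sizes are examples), matching the
// docopt usage string above:
//
//	./cube-agent -P 8000 -n 10 -q 20 -l 16 --log_dir=./log --log_name=m_agent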
......@@ -12,15 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "examples/demo-serving/op/general_infer_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/general-server/op/general_infer_op.h"
#include "core/general-server/op/general_reader_op.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
#include "examples/demo-serving/op/general_reader_op.h"
namespace baidu {
namespace paddle_serving {
......
......@@ -23,7 +23,7 @@
#else
#include "paddle_inference_api.h" // NOLINT
#endif
#include "examples/demo-serving/general_model_service.pb.h"
#include "core/general-server/general_model_service.pb.h"
namespace baidu {
namespace paddle_serving {
......
......@@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "examples/demo-serving/op/general_reader_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/general-server/op/general_reader_op.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
......
......@@ -25,8 +25,8 @@
#endif
#include <string>
#include "core/predictor/framework/resource.h"
#include "examples/demo-serving/general_model_service.pb.h"
#include "examples/demo-serving/load_general_model_service.pb.h"
#include "core/general-server/general_model_service.pb.h"
#include "core/general-server/load_general_model_service.pb.h"
namespace baidu {
namespace paddle_serving {
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "examples/demo-serving/op/general_infer_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
#include "examples/demo-serving/op/general_reader_op.h"
namespace baidu {
namespace paddle_serving {
namespace serving {
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::FetchInst;
using baidu::paddle_serving::predictor::InferManager;
int GeneralInferOp::inference() {
const GeneralReaderOutput *reader_out =
get_depend_argument<GeneralReaderOutput>("general_reader_op");
if (!reader_out) {
LOG(ERROR) << "Failed mutable depended argument, op:"
<< "general_reader_op";
return -1;
}
int reader_status = reader_out->reader_status;
if (reader_status != 0) {
LOG(ERROR) << "Read request wrong.";
return -1;
}
const TensorVector *in = &reader_out->tensor_vector;
TensorVector *out = butil::get_object<TensorVector>();
int batch_size = (*in)[0].shape[0];
// infer
if (InferManager::instance().infer(GENERAL_MODEL_NAME, in, out, batch_size)) {
LOG(ERROR) << "Failed do infer in fluid model: " << GENERAL_MODEL_NAME;
return -1;
}
Response *res = mutable_data<Response>();
for (int i = 0; i < batch_size; ++i) {
FetchInst *fetch_inst = res->add_insts();
for (int j = 0; j < out->size(); ++j) {
Tensor *tensor = fetch_inst->add_tensor_array();
tensor->set_elem_type(1);
if (out->at(j).lod.size() == 1) {
tensor->add_shape(-1);
} else {
for (int k = 1; k < out->at(j).shape.size(); ++k) {
tensor->add_shape(out->at(j).shape[k]);
}
}
}
}
for (int i = 0; i < out->size(); ++i) {
float *data_ptr = static_cast<float *>(out->at(i).data.data());
int cap = 1;
for (int j = 1; j < out->at(i).shape.size(); ++j) {
cap *= out->at(i).shape[j];
}
if (out->at(i).lod.size() == 1) {
for (int j = 0; j < batch_size; ++j) {
for (int k = out->at(i).lod[0][j]; k < out->at(i).lod[0][j + 1]; k++) {
res->mutable_insts(j)->mutable_tensor_array(i)->add_data(
reinterpret_cast<char *>(&(data_ptr[k])), sizeof(float));
}
}
} else {
for (int j = 0; j < batch_size; ++j) {
for (int k = j * cap; k < (j + 1) * cap; ++k) {
res->mutable_insts(j)->mutable_tensor_array(i)->add_data(
reinterpret_cast<char *>(&(data_ptr[k])), sizeof(float));
}
}
}
}
/*
for (size_t i = 0; i < in->size(); ++i) {
(*in)[i].shape.clear();
}
in->clear();
butil::return_object<TensorVector>(in);
for (size_t i = 0; i < out->size(); ++i) {
(*out)[i].shape.clear();
}
out->clear();
butil::return_object<TensorVector>(out);
}
*/
return 0;
}
DEFINE_OP(GeneralInferOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#ifdef BCLOUD
#ifdef WITH_GPU
#include "paddle/paddle_inference_api.h"
#else
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#endif
#else
#include "paddle_inference_api.h" // NOLINT
#endif
#include "examples/demo-serving/general_model_service.pb.h"
namespace baidu {
namespace paddle_serving {
namespace serving {
static const char* GENERAL_MODEL_NAME = "general_model";
class GeneralInferOp
: public baidu::paddle_serving::predictor::OpWithChannel<
baidu::paddle_serving::predictor::general_model::Response> {
public:
typedef std::vector<paddle::PaddleTensor> TensorVector;
DECLARE_OP(GeneralInferOp);
int inference();
};
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "examples/demo-serving/op/general_reader_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
namespace baidu {
namespace paddle_serving {
namespace serving {
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FeedInst;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int conf_check(const Request *req,
const std::shared_ptr<PaddleGeneralModelConfig> &model_config) {
int var_num = req->insts(0).tensor_array_size();
VLOG(2) << "var num: " << var_num;
if (var_num != model_config->_feed_type.size()) {
LOG(ERROR) << "feed var number not match.";
return -1;
}
VLOG(2) << "begin to checkout feed type";
for (int i = 0; i < var_num; ++i) {
VLOG(2) << "feed type[" << i << "]: " <<
model_config->_feed_type[i];
if (model_config->_feed_type[i] !=
req->insts(0).tensor_array(i).elem_type()) {
LOG(ERROR) << "feed type not match.";
return -1;
}
VLOG(2) << "feed shape size: " << model_config->_feed_shape[i].size();
if (model_config->_feed_shape[i].size() ==
req->insts(0).tensor_array(i).shape_size()) {
for (int j = 0; j < model_config->_feed_shape[i].size(); ++j) {
req->insts(0).tensor_array(i).shape(j);
if (model_config->_feed_shape[i][j] !=
req->insts(0).tensor_array(i).shape(j)) {
LOG(ERROR) << "feed shape not match.";
return -1;
}
}
} else {
LOG(ERROR) << "feed shape not match.";
return -1;
}
}
return 0;
}
int GeneralReaderOp::inference() {
// read request from client
const Request *req = dynamic_cast<const Request *>(get_request_message());
int batch_size = req->insts_size();
int input_var_num = 0;
std::vector<int64_t> elem_type;
std::vector<int64_t> elem_size;
std::vector<int64_t> capacity;
GeneralReaderOutput *res = mutable_data<GeneralReaderOutput>();
if (!res) {
LOG(ERROR) << "Failed to get op tls reader object output";
return -1;
}
TensorVector *in = &res->tensor_vector;
if (batch_size <= 0) {
res->reader_status = -1;
return 0;
}
int var_num = req->insts(0).tensor_array_size();
VLOG(2) << "var num: " << var_num;
// read config
LOG(INFO) << "start to call load general model_conf op";
baidu::paddle_serving::predictor::Resource &resource =
baidu::paddle_serving::predictor::Resource::instance();
LOG(INFO) << "get resource pointer done.";
std::shared_ptr<PaddleGeneralModelConfig> model_config =
resource.get_general_model_config();
LOG(INFO) << "print general model config done.";
// check
res->reader_status = conf_check(req, model_config);
if (res->reader_status != 0) {
LOG(INFO) << "model conf of server:";
resource.print_general_model_config(model_config);
return 0;
}
// package tensor
elem_type.resize(var_num);
elem_size.resize(var_num);
capacity.resize(var_num);
paddle::PaddleTensor lod_tensor;
for (int i = 0; i < var_num; ++i) {
elem_type[i] = req->insts(0).tensor_array(i).elem_type();
VLOG(2) << "var[" << i << "] has elem type: " << elem_type[i];
if (elem_type[i] == 0) { // int64
elem_size[i] = sizeof(int64_t);
lod_tensor.dtype = paddle::PaddleDType::INT64;
} else {
elem_size[i] = sizeof(float);
lod_tensor.dtype = paddle::PaddleDType::FLOAT32;
}
if (req->insts(0).tensor_array(i).shape(0) == -1) {
lod_tensor.lod.resize(1);
lod_tensor.lod[0].push_back(0);
VLOG(2) << "var[" << i << "] is lod_tensor";
} else {
lod_tensor.shape.push_back(batch_size);
capacity[i] = 1;
for (int k = 0; k < req->insts(0).tensor_array(i).shape_size(); ++k) {
int dim = req->insts(0).tensor_array(i).shape(k);
VLOG(2) << "shape for var[" << i << "]: " << dim;
capacity[i] *= dim;
lod_tensor.shape.push_back(dim);
}
VLOG(2) << "var[" << i << "] is tensor, capacity: " << capacity[i];
}
if (i == 0) {
lod_tensor.name = "words";
} else {
lod_tensor.name = "label";
}
in->push_back(lod_tensor);
}
for (int i = 0; i < var_num; ++i) {
if (in->at(i).lod.size() == 1) {
for (int j = 0; j < batch_size; ++j) {
const Tensor &tensor = req->insts(j).tensor_array(i);
int data_len = tensor.data_size();
VLOG(2) << "tensor size for var[" << i << "]: " << tensor.data_size();
int cur_len = in->at(i).lod[0].back();
VLOG(2) << "current len: " << cur_len;
in->at(i).lod[0].push_back(cur_len + data_len);
VLOG(2) << "new len: " << cur_len + data_len;
}
in->at(i).data.Resize(in->at(i).lod[0].back() * elem_size[i]);
in->at(i).shape = {in->at(i).lod[0].back(), 1};
VLOG(2) << "var[" << i
<< "] is lod_tensor and len=" << in->at(i).lod[0].back();
} else {
in->at(i).data.Resize(batch_size * capacity[i] * elem_size[i]);
VLOG(2) << "var[" << i
<< "] is tensor and capacity=" << batch_size * capacity[i];
}
}
for (int i = 0; i < var_num; ++i) {
if (elem_type[i] == 0) {
int64_t *dst_ptr = static_cast<int64_t *>(in->at(i).data.data());
int offset = 0;
for (int j = 0; j < batch_size; ++j) {
for (int k = 0; k < req->insts(j).tensor_array(i).data_size(); ++k) {
dst_ptr[offset + k] =
*(const int64_t *)req->insts(j).tensor_array(i).data(k).c_str();
}
if (in->at(i).lod.size() == 1) {
offset = in->at(i).lod[0][j + 1];
} else {
offset += capacity[i];
}
}
} else {
float *dst_ptr = static_cast<float *>(in->at(i).data.data());
int offset = 0;
for (int j = 0; j < batch_size; ++j) {
for (int k = 0; k < req->insts(j).tensor_array(i).data_size(); ++k) {
dst_ptr[offset + k] =
*(const float *)req->insts(j).tensor_array(i).data(k).c_str();
}
if (in->at(i).lod.size() == 1) {
offset = in->at(i).lod[0][j + 1];
} else {
offset += capacity[i];
}
}
}
}
VLOG(2) << "read data from client success";
// print request
std::ostringstream oss;
int64_t *example = reinterpret_cast<int64_t *>((*in)[0].data.data());
for (int i = 0; i < 10; i++) {
oss << *(example + i) << " ";
}
VLOG(2) << "head element of first feed var : " << oss.str();
//
return 0;
}
DEFINE_OP(GeneralReaderOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#ifdef BCLOUD
#ifdef WITH_GPU
#include "paddle/paddle_inference_api.h"
#else
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#endif
#else
#include "paddle_inference_api.h" // NOLINT
#endif
#include <string>
#include "core/predictor/framework/resource.h"
#include "examples/demo-serving/general_model_service.pb.h"
#include "examples/demo-serving/load_general_model_service.pb.h"
namespace baidu {
namespace paddle_serving {
namespace serving {
struct GeneralReaderOutput {
std::vector<paddle::PaddleTensor> tensor_vector;
int reader_status = 0;
void Clear() {
size_t tensor_count = tensor_vector.size();
for (size_t ti = 0; ti < tensor_count; ++ti) {
tensor_vector[ti].shape.clear();
}
tensor_vector.clear();
}
std::string ShortDebugString() const { return "Not implemented!"; }
};
class GeneralReaderOp : public baidu::paddle_serving::predictor::OpWithChannel<
GeneralReaderOutput> {
public:
typedef std::vector<paddle::PaddleTensor> TensorVector;
DECLARE_OP(GeneralReaderOp);
int inference();
};
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
......@@ -141,6 +141,9 @@ class Client(object):
result = self.client_handle_.predict(
float_slot, float_feed_names, int_slot, int_feed_names, fetch_names)
# TODO(guru4elephant): the order of fetch var names should be consistent with
# general_model_config; this is not user-friendly.
# In the future, we need to make the number of fetched variables changeable.
result_map = {}
for i, name in enumerate(fetch_names):
result_map[name] = result[i]
......
......@@ -18,7 +18,7 @@ from paddle.fluid.framework import core
from paddle.fluid.framework import default_main_program
from paddle.fluid.framework import Program
from paddle.fluid import CPUPlace
from paddle.fluid.io import save_persistables
from paddle.fluid.io import save_inference_model
from ..proto import general_model_config_pb2 as model_conf
import os
......@@ -27,19 +27,13 @@ def save_model(server_model_folder,
feed_var_dict,
fetch_var_dict,
main_program=None):
if main_program is None:
main_program = default_main_program()
elif isinstance(main_program, CompiledProgram):
main_program = main_program._program
if main_program is None:
raise TypeError("program should be as Program type or None")
if not isinstance(main_program, Program):
raise TypeError("program should be as Program type or None")
executor = Executor(place=CPUPlace())
save_persistables(executor, server_model_folder,
main_program)
feed_var_names = [feed_var_dict[x].name for x in feed_var_dict]
target_vars = fetch_var_dict.values()
save_inference_model(server_model_folder, feed_var_names,
target_vars, executor, main_program=main_program)
config = model_conf.GeneralModelConfig()
......@@ -71,10 +65,11 @@ def save_model(server_model_folder,
config.fetch_var.extend([fetch_var])
cmd = "mkdir -p {}".format(client_config_folder)
os.system(cmd)
with open("{}/serving_client_conf.prototxt", "w") as fout:
with open("{}/serving_client_conf.prototxt".format(client_config_folder), "w") as fout:
fout.write(str(config))
with open("{}/serving_server_conf.prototxt", "w") as fout:
with open("{}/serving_server_conf.prototxt".format(server_model_folder), "w") as fout:
fout.write(str(config))
......
......@@ -175,7 +175,7 @@ class Server(object):
def run_server(self):
# just run server with system command
# currently we do not load cube
command = "/home/users/dongdaxiang/github_develop/Serving/build_server/core/general-server" \
command = "/home/users/dongdaxiang/github_develop/Serving/build_server/core/general-server/serving" \
" -enable_model_toolkit " \
"-inferservice_path {} " \
"-inferservice_file {} " \
......