Commit 1ab5b8ef authored by liqingping

feat: init commit

# More info: https://docs.docker.com/engine/reference/builder/#dockerignore-file
# Ignore all files which are not Go source, module, or checksum files
!**/*.go
!**/*.mod
!**/*.sum
name: Go
on:
- push
env:
VERSION: v0.0.1-alpha.0
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.16
- name: lint
shell: bash
run: |
# binary will be $(go env GOPATH)/bin/golangci-lint
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.41.1
# or install it into ./bin/
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s v1.41.1
# In alpine linux (as it does not come with curl by default)
wget -O- -nv https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s v1.41.1
golangci-lint --version
make lint
unit-test:
needs: lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.16
- name: Test
shell: bash
run: |
# download etcd to bootstrap test environment
curl -L https://github.com/kubernetes-sigs/kubebuilder/releases/download/v2.3.2/kubebuilder_2.3.2_linux_amd64.tar.gz | tar -xz -C /tmp/
mv /tmp/kubebuilder_2.3.2_linux_amd64 /tmp/kubebuilder
export KUBEBUILDER_ASSETS=/tmp/kubebuilder/bin
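# envtest (used by `make test`) locates the etcd and kube-apiserver binaries via KUBEBUILDER_ASSETS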
make test
build:
needs: unit-test
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.16
- name: Build Images
run: make dev-images
- name: Push Images
run: make docker-push
name: Publish Releases to Hub
# When it's time to do a release, do a full cross-platform build for all supported
# architectures and push all of them to Docker Hub.
# Only trigger on semver shaped tags.
on:
push:
tags:
- "v*.*.*"
jobs:
docker:
runs-on: ubuntu-latest
strategy:
matrix:
platform: [ linux/amd64 ]
target: [ di-operator, di-server ]
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Prepare
id: prep
env:
DOCKERIO_ORG: diorchestrator
TARGET: ${{ matrix.target }}
run: |
DOCKER_IMAGE=$DOCKERIO_ORG/$TARGET
VERSION=edge
if [[ $GITHUB_REF == refs/tags/* ]]; then
VERSION=${GITHUB_REF#refs/tags/}
fi
if [ "${{ github.event_name }}" = "schedule" ]; then
VERSION=nightly
fi
TAGS="${DOCKER_IMAGE}:${VERSION}"
if [[ $VERSION =~ ^v?[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$ ]]; then
TAGS="$TAGS,${DOCKER_IMAGE}:latest"
fi
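# e.g. pushing tag v1.0.0 gives VERSION=v1.0.0, so the di-operator image is tagged diorchestrator/di-operator:v1.0.0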
echo ::set-output name=tags::${TAGS}
- name: Set up QEMU
uses: docker/setup-qemu-action@v1
with:
platforms: all
- name: Set up Docker Buildx
id: buildx
uses: docker/setup-buildx-action@v1
- name: Cache Docker layers
uses: actions/cache@v2
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-${{ github.sha }}
restore-keys: |
${{ runner.os }}-buildx-
- name: Login to DockerHub
if: github.event_name != 'pull_request'
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKERIO_USERNAME }}
password: ${{ secrets.DOCKERIO_PASSWORD }}
- name: Build and push
id: docker_build
uses: docker/build-push-action@v2
env:
TARGET: ${{ matrix.target }}
with:
builder: ${{ steps.buildx.outputs.name }}
context: ./
file: ./Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.prep.outputs.tags }}
cache-from: type=local,src=/tmp/.buildx-cache
cache-to: type=local,dest=/tmp/.buildx-cache
target: ${{ matrix.target }}
- name: Image digest
run: echo ${{ steps.docker_build.outputs.digest }}
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
bin
# Test binary, build with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Kubernetes Generated files - skip generated files, except for vendored files
!vendor/**/zz_generated.*
# editor and IDE paraphernalia
.idea
*.swp
*.swo
*~
*.vscode
*.coverprofile
coverage.out.*
config/webhook/manifests.yaml
run:
# default concurrency is the available CPU number
concurrency: 4
# timeout for analysis, e.g. 30s, 5m, default is 1m
deadline: 10m
# exit code when at least one issue was found, default is 1
issues-exit-code: 1
# include test files or not, default is true
tests: true
skip-dirs:
- manifests # deployment yaml
- third_party # third-party code
- _out # built binaries
- doc # user tutorial
- deployment # deployment yaml
- config # the crd config yaml
- cluster # cluster scripts
- vendor # vendored third-party libraries
- api # auto-generated
- pkg/client # auto-generated
- example
- bin
# output configuration options
output:
# colored-line-number|line-number|json|tab|checkstyle|code-climate, default is "colored-line-number"
format: colored-line-number
# print lines of code with issue, default is true
print-issued-lines: true
# print linter name in the end of issue text, default is true
print-linter-name: true
linters:
fast: true
enable:
- gofmt
- goimports
- golint
- deadcode
disable:
- gocyclo
- typecheck
- bodyclose
- gochecknoinits
- gochecknoglobals
- lll
- maligned
- unparam
- unused
- depguard
- dupl
- errcheck
- gas
- goconst
- gocritic
- gosec
- gosimple
- govet
- interfacer
- ineffassign
- megacheck
- misspell
- nakedret
- prealloc
- staticcheck
- structcheck
- stylecheck
- unconvert
- varcheck
# Build the di-operator binary
FROM golang:1.14 as builder
WORKDIR /workspace
# Copy the Go Modules manifests
COPY go.mod go.mod
COPY go.sum go.sum
# cache deps before building and copying source so that we don't need to re-download as much
# and so that source changes don't invalidate our downloaded layer
ARG GOPROXY=https://goproxy.cn
RUN go mod download
# Copy the go source
COPY main.go main.go
COPY api/ api/
COPY common/ common/
COPY controllers/ controllers/
COPY server/ server/
COPY utils/ utils/
# Build
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 GO111MODULE=on go build -a -o di-operator main.go
# Build
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 GO111MODULE=on go build -a -o di-server server/main.go
# Use a minimal base image to package the di-operator binary
# (the kubebuilder scaffold suggests distroless, see https://github.com/GoogleContainerTools/distroless)
FROM redhat/ubi8:latest as di-operator
LABEL maintainer="opendilab.contact@gmail.com"
RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime
WORKDIR /
COPY --from=builder /workspace/di-operator .
ENTRYPOINT ["/di-operator"]
FROM redhat/ubi8:latest as di-server
LABEL maintainer="opendilab.contact@gmail.com"
RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime
WORKDIR /
COPY --from=builder /workspace/di-server .
ENTRYPOINT ["/di-server"]
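# A single stage can be built locally, e.g. (mirrors the Makefile's docker-build target):
#   docker build --target di-operator -t di-operator:dev .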
# di-operator version
VERSION ?= v0.1.0
MASTER_VERSION := $(VERSION)
COMMIT_SHORT_SHA=$(shell git log -n 1 | head -n 1 | sed -e 's/^commit //' | head -c 8)
VERSION := $(VERSION)-${COMMIT_SHORT_SHA}
ifeq ($(GIT_BRANCH),master)
VERSION := $(MASTER_VERSION)
endif
ifneq ($(findstring release,$(GIT_BRANCH)),)
VERSION := $(MASTER_VERSION)
endif
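# e.g. for this commit a feature-branch build is tagged v0.1.0-1ab5b8ef, while master/release builds keep the plain v0.1.0 tag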
# Image URL to use for all building/pushing image targets
IMG_BASE ?= diorchestrator/di-operator
SERVER_IMG_BASE ?= diorchestrator/di-server
IMG ?= ${IMG_BASE}:${VERSION}
MASTER_IMG ?= ${IMG_BASE}:${MASTER_VERSION}
SERVER_IMG ?= ${SERVER_IMG_BASE}:${VERSION}
MASTER_SERVER_IMG ?= ${SERVER_IMG_BASE}:${MASTER_VERSION}
# Produce CRDs that work back to Kubernetes 1.11 (no version conversion)
CRD_OPTIONS ?= "crd:trivialVersions=true,preserveUnknownFields=false"
# Get the currently used golang install path (in GOPATH/bin, unless GOBIN is set)
ifeq (,$(shell go env GOBIN))
GOBIN=$(shell go env GOPATH)/bin
else
GOBIN=$(shell go env GOBIN)
endif
all: build
##@ General
# The help target prints out all targets with their descriptions organized
# beneath their categories. The categories are represented by '##@' and the
# target descriptions by '##'. The awk command is responsible for reading the
# entire set of makefiles included in this invocation, looking for lines of the
# file as xyz: ## something, and then pretty-format the target and help. Then,
# if there's a line with ##@ something, that gets pretty-printed as a category.
# More info on the usage of ANSI control characters for terminal formatting:
# https://en.wikipedia.org/wiki/ANSI_escape_code#SGR_parameters
# More info on the awk command:
# http://linuxcommand.org/lc3_adv_awk.php
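# For example, the rule `deploy: manifests kustomize ## Deploy controller ...` below
# is listed by `make help` under the '##@ Deployment' category.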
help: ## Display this help.
@awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m<target>\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST)
##@ Development
manifests: controller-gen ## Generate WebhookConfiguration, ClusterRole and CustomResourceDefinition objects.
$(CONTROLLER_GEN) $(CRD_OPTIONS) rbac:roleName=di-operator-cluster-role webhook paths="./..." output:crd:artifacts:config=config/crd/bases
cd config/manager && $(KUSTOMIZE) edit set image ${IMG_BASE}=${MASTER_IMG} ${SERVER_IMG_BASE}=${MASTER_SERVER_IMG}
./hack/update-image-tags.sh config/manager ${MASTER_VERSION}
# dev-manifests appends COMMIT_SHORT_SHA to the version and image tag, so it is only used for development
# use `make manifests` when committing to git
dev-manifests: controller-gen ## Generate WebhookConfiguration, ClusterRole and CustomResourceDefinition objects.
$(CONTROLLER_GEN) $(CRD_OPTIONS) rbac:roleName=di-operator-cluster-role webhook paths="./..." output:crd:artifacts:config=config/crd/bases
cd config/manager && $(KUSTOMIZE) edit set image ${IMG_BASE}=${IMG} ${SERVER_IMG_BASE}=${SERVER_IMG}
./hack/update-image-tags.sh config/manager ${VERSION}
generate: controller-gen ## Generate code containing DeepCopy, DeepCopyInto, and DeepCopyObject method implementations.
$(CONTROLLER_GEN) object:headerFile="hack/boilerplate.go.txt" paths="./..."
fmt: ## Run go fmt against code.
go fmt ./...
vet: ## Run go vet against code.
go vet ./...
# Run golangci-lint
lint:
golangci-lint run -v --timeout=5m
.PHONY: test
test: ginkgo ## Run tests.
$(GINKGO) -nodes 4 -v -cover -coverprofile=coverage.out ./...
##@ Build
build: generate ## Build di-operator binary.
go build -o bin/di-operator main.go
go build -o bin/di-server server/main.go
run: manifests generate fmt vet ## Run a controller from your host.
go run ./main.go
docker-build: ## Build docker image with the di-operator.
docker build -t ${IMG} --target di-operator .
docker build -t ${SERVER_IMG} --target di-server .
docker-push: ## Push docker image with the di-operator.
docker push ${IMG}
docker push ${SERVER_IMG}
docker-release: ## Release docker image with the di-operator.
docker tag ${IMG} ${MASTER_IMG}
docker tag ${SERVER_IMG} ${MASTER_SERVER_IMG}
docker push ${MASTER_IMG}
docker push ${MASTER_SERVER_IMG}
##@ Deployment
install: manifests kustomize ## Install CRDs into the K8s cluster specified in ~/.kube/config.
$(KUSTOMIZE) build config/crd | kubectl apply -f -
uninstall: manifests kustomize ## Uninstall CRDs from the K8s cluster specified in ~/.kube/config.
$(KUSTOMIZE) build config/crd | kubectl delete -f -
deploy: manifests kustomize ## Deploy controller to the K8s cluster specified in ~/.kube/config.
$(KUSTOMIZE) build config/default | kubectl apply -f -
dev-deploy: dev-manifests kustomize ## Deploy controller to the K8s cluster specified in ~/.kube/config.
$(KUSTOMIZE) build config/default | kubectl apply -f -
installer-gen: manifests kustomize ## generate di-manager.yaml
$(KUSTOMIZE) build config/default > config/di-manager.yaml
undeploy: ## Undeploy controller from the K8s cluster specified in ~/.kube/config.
$(KUSTOMIZE) build config/default | kubectl delete -f -
dev-undeploy: ## Undeploy controller from the K8s cluster specified in ~/.kube/config.
$(KUSTOMIZE) build config/default | kubectl delete -f -
CONTROLLER_GEN = $(shell pwd)/bin/controller-gen
controller-gen: ## Download controller-gen locally if necessary.
$(call go-get-tool,$(CONTROLLER_GEN),sigs.k8s.io/controller-tools/cmd/controller-gen@v0.4.1)
KUSTOMIZE = $(shell pwd)/bin/kustomize
kustomize: ## Download kustomize locally if necessary.
$(call go-get-tool,$(KUSTOMIZE),sigs.k8s.io/kustomize/kustomize/v3@v3.8.7)
GINKGO = $(shell pwd)/bin/ginkgo
ginkgo: ## Download ginkgo locally if necessary.
$(call go-get-tool,$(GINKGO),github.com/onsi/ginkgo/ginkgo@v1.14.1)
# go-get-tool will 'go get' any package $2 and install it to $1.
PROJECT_DIR := $(shell dirname $(abspath $(lastword $(MAKEFILE_LIST))))
define go-get-tool
@[ -f $(1) ] || { \
set -e ;\
TMP_DIR=$$(mktemp -d) ;\
cd $$TMP_DIR ;\
go mod init tmp ;\
echo "Downloading $(2)" ;\
GOBIN=$(PROJECT_DIR)/bin go get $(2) ;\
rm -rf $$TMP_DIR ;\
}
endef
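# e.g. $(call go-get-tool,$(GINKGO),github.com/onsi/ginkgo/ginkgo@v1.14.1) installs ginkgo into ./bin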
domain: opendilab.org
layout:
- go.kubebuilder.io/v3
projectName: di
repo: opendilab.org/di-orchestrator
resources:
- api:
crdVersion: v1
namespaced: true
controller: true
domain: opendilab.org
group: diengine
kind: DIJob
path: opendilab.org/di-orchestrator/api/v1alpha1
version: v1alpha1
webhooks:
defaulting: true
validation: true
webhookVersion: v1
- api:
crdVersion: v1
namespaced: true
domain: opendilab.org
group: diengine
kind: AggregatorConfig
path: opendilab.org/di-orchestrator/api/v1alpha1
version: v1alpha1
version: "3"
# DI Orchestrator
DI Orchestrator is designed to manage DI (Decision Intelligence) jobs using Kubernetes Custom Resource and Operator.
### Prerequisites
- A well-prepared Kubernetes cluster. Follow the [instructions](https://kubernetes.io/docs/setup/production-environment/tools/kubeadm/create-cluster-kubeadm/) to create a Kubernetes cluster, or create a local single-node cluster with [kind](https://kind.sigs.k8s.io/docs/user/quick-start/) or [minikube](https://minikube.sigs.k8s.io/docs/start/) (a kind sketch follows the cert-manager step below).
- cert-manager. Refer to the [cert-manager docs](https://cert-manager.io/docs/installation/kubernetes/) for installation on Kubernetes, or install it with the following command.
```bash
kubectl create -f ./config/certmanager/cert-manager.yaml
```
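If no cluster is at hand, a local kind cluster is enough for trying the examples below; the cluster name used here is only an illustration.
```bash
# create a throwaway local cluster (any name works)
kind create cluster --name di-demo
kubectl cluster-info --context kind-di-demo
```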
### Install DI Orchestrator
Install `di-operator` and `di-server` with the following command.
```bash
kubectl create -f ./config/di-manager.yaml
```
`di-operator` and `di-server` will be installed in `di-system` namespace.
```bash
$ kubectl get pod -n di-system
NAME READY STATUS RESTARTS AGE
di-operator-57cc65d5c9-5vnvn 1/1 Running 0 59s
di-server-7b86ff8df4-jfgmp 1/1 Running 0 59s
```
Install global components of DIJob defined in AggregatorConfig:
```bash
kubectl create -f examples/di_v1alpha1_agconfig.yaml -n di-system
```
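To confirm that the CRDs and the AggregatorConfig are in place, a quick check (using the `agconfig` short name defined by the CRD) looks like:
```bash
kubectl get crd | grep diengine.opendilab.org
kubectl get agconfig -n di-system
```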
### Submit DIJob
```bash
# submit DIJob
$ kubectl create -f examples/di_v1alpha1_dijob.yaml
# get pods and you will see the coordinator created by di-operator
# a few seconds later, you will see collectors and learners created by di-server
$ kubectl get pod
# get logs of coordinator
$ kubectl logs dijob-example-coordinator
```
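The shipped example manifest lives under `examples/`; for orientation, a minimal hand-written DIJob sketched from the `DIJobSpec` fields in this repository might look like the following (the name, image, and commands are placeholders, not the contents of the shipped example):
```bash
kubectl apply -f - <<EOF
apiVersion: diengine.opendilab.org/v1alpha1
kind: DIJob
metadata:
  name: dijob-sketch
spec:
  cleanPodPolicy: Running          # one of Running, ALL, None; defaults to Running
  coordinator:
    template:
      spec:
        containers:
        - name: coordinator
          image: alpine:latest     # placeholder image
          command: ["/bin/sh", "-c", "sleep 3600"]
  collector:
    template:
      spec:
        containers:
        - name: collector
          image: alpine:latest
          command: ["/bin/sh", "-c", "sleep 3600"]
  learner:
    template:
      spec:
        containers:
        - name: learner
          image: alpine:latest
          command: ["/bin/sh", "-c", "sleep 3600"]
EOF
```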
## User Guide
Refer to the [user guide](./docs/architecture.md). For the Chinese version, please refer to the [中文手册](./docs/architecture-cn.md).
## Contributing
Refer to the [developer guide](./docs/developer-guide.md). Contact us through <opendilab.contact@gmail.com>
/*
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package v1alpha1
import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
// EDIT THIS FILE! THIS IS SCAFFOLDING FOR YOU TO OWN!
// NOTE: json tags are required. Any new fields you add must have json tags for the fields to be serialized.
// AggregatorConfigSpec defines the desired state of AggregatorConfig
type AggregatorConfigSpec struct {
// INSERT ADDITIONAL SPEC FIELDS - desired state of cluster
// Important: Run "make" to regenerate code after modifying this file
Aggregator AggregatorSpec `json:"aggregator,"`
}
// AggregatorSpec defines the desired state of aggregators
type AggregatorSpec struct {
Template corev1.PodTemplateSpec `json:"template,"`
}
// AggregatorConfigStatus defines the observed state of AggregatorConfig
type AggregatorConfigStatus struct {
// INSERT ADDITIONAL STATUS FIELD - define observed state of cluster
// Important: Run "make" to regenerate code after modifying this file
Actors *AggregatorReplicaStatus `json:"actors,omitempty"`
Learners *AggregatorReplicaStatus `json:"learners,omitempty"`
}
// AggregatorReplicaStatus defines the observed state of actors' and learners' replicas
type AggregatorReplicaStatus struct {
Total int32 `json:"total,omitempty"`
Active int32 `json:"active,omitempty"`
}
// +kubebuilder:object:root=true
// +kubebuilder:subresource:status
// +kubebuilder:resource:shortName=agconfig
// +kubebuilder:printcolumn:name="Age",type=date,JSONPath=`.metadata.creationTimestamp`
// AggregatorConfig is the Schema for the AggregatorConfigs API
type AggregatorConfig struct {
metav1.TypeMeta `json:",inline"`
metav1.ObjectMeta `json:"metadata,omitempty"`
Spec AggregatorConfigSpec `json:"spec,omitempty"`
// Status AggregatorConfigStatus `json:"status,omitempty"`
}
// +kubebuilder:object:root=true
// AggregatorConfigList contains a list of AggregatorConfig
type AggregatorConfigList struct {
metav1.TypeMeta `json:",inline"`
metav1.ListMeta `json:"metadata,omitempty"`
Items []AggregatorConfig `json:"items"`
}
func init() {
SchemeBuilder.Register(&AggregatorConfig{}, &AggregatorConfigList{})
}
/*
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package v1alpha1
import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
// EDIT THIS FILE! THIS IS SCAFFOLDING FOR YOU TO OWN!
// NOTE: json tags are required. Any new fields you add must have json tags for the fields to be serialized.
// DIJobSpec defines the desired state of DIJob
type DIJobSpec struct {
// INSERT ADDITIONAL SPEC FIELDS - desired state of cluster
// Important: Run "make" to regenerate code after modifying this file
// Group is a collection of DIJobs
Group string `json:"group,omitempty"`
// PriorityClassName labels the priority of DIJob
PriorityClassName PriorityClassName `json:"priorityClassName,omitempty"`
// CleanPodPolicy defines the policy to clean pods after DIJob completed
CleanPodPolicy CleanPodPolicy `json:"cleanPodPolicy,omitempty"`
// Volumes defines the shared volumes for di components
Volumes []corev1.Volume `json:"volumes,omitempty"`
Coordinator CoordinatorSpec `json:"coordinator"`
Collector CollectorSpec `json:"collector,"`
Learner LearnerSpec `json:"learner,"`
}
// PriorityClassName defines the priority class name of DIJob
type PriorityClassName string
const (
// NormalPriority is normal priority
NormalPriority PriorityClassName = "default"
// HighPriority is high priority
HighPriority PriorityClassName = "high"
)
type CleanPodPolicy string
const (
// CleanPodPolicyRunning means deleting all running pods of the job after completed
CleanPodPolicyRunning CleanPodPolicy = "Running"
// CleanPodPolicyALL means deleting all pods of the job after completed
CleanPodPolicyALL CleanPodPolicy = "ALL"
// CleanPodPolicyNone means never deleting any pods of the job after completed
CleanPodPolicyNone CleanPodPolicy = "None"
)
// CoordinatorSpec defines the desired state of coordinators
type CoordinatorSpec struct {
Template corev1.PodTemplateSpec `json:"template"`
}
// CollectorSpec defines the desired state of collectors
type CollectorSpec struct {
Template corev1.PodTemplateSpec `json:"template,"`
}
// LearnerSpec defines the desired state of learners
type LearnerSpec struct {
Template corev1.PodTemplateSpec `json:"template,"`
}
// DIJobStatus defines the observed state of DIJob
type DIJobStatus struct {
// INSERT ADDITIONAL STATUS FIELD - define observed state of cluster
// Important: Run "make" to regenerate code after modifying this file
Phase Phase `json:"phase,omitempty"`
Conditions []DIJobCondition `json:"conditions,omitempty"`
ReplicaStatus map[ReplicaType]*ReplicaStatus `json:"replicaStatus,omitempty"`
}
// Phase defines the phase of DIJob
type Phase string
const (
// JobCreated means the job has been submitted to the cluster,
// but not all the pods and services have been created,
// or some pods are not yet running
JobCreated Phase = "Created"
// JobRunning means all the pods are in running state
JobRunning Phase = "Running"
// JobSucceeded means job completed without error
JobSucceeded Phase = "Succeeded"
// JobFailed means some pods failed, job is also considered failed
JobFailed Phase = "Failed"
// JobUnknown means the job is in unknown state
JobUnknown Phase = "Unknown"
)
// ReplicaType represents the type of the replica. Each operator needs to define its
// own set of ReplicaTypes.
type ReplicaType string
const (
ReplicaTypeCollector ReplicaType = "Collector"
ReplicaTypeLearner ReplicaType = "Learner"
ReplicaTypeAggregator ReplicaType = "Aggregator"
ReplicaTypeCoordinator ReplicaType = "Coordinator"
)
// ReplicaStatus represents the current observed state of the replica.
type ReplicaStatus struct {
// The number of actively running pods.
Active int32 `json:"active,omitempty"`
// The number of pods which reached phase Succeeded.
Succeeded int32 `json:"succeeded,omitempty"`
// The number of pods which reached phase Failed.
Failed int32 `json:"failed,omitempty"`
}
// DIJobCondition records the conditions of DIJob
type DIJobCondition struct {
// Type of job condition.
Type Phase `json:"type"`
// Status of the condition, one of True, False, Unknown.
Status corev1.ConditionStatus `json:"status"`
// The reason for the condition's last transition.
Reason string `json:"reason,omitempty"`
// A human readable message indicating details about the transition.
Message string `json:"message,omitempty"`
// The last time this condition was updated.
LastUpdateTime metav1.Time `json:"lastUpdateTime,omitempty"`
// Last time the condition transitioned from one status to another.
LastTransitionTime metav1.Time `json:"lastTransitionTime,omitempty"`
}
// +kubebuilder:object:root=true
// +kubebuilder:subresource:status
// +kubebuilder:resource:shortName=dijob
// +kubebuilder:printcolumn:name="Phase",type=string,JSONPath=`.status.phase`
// +kubebuilder:printcolumn:name="Age",type=date,JSONPath=`.metadata.creationTimestamp`
// DIJob is the Schema for the dijobs API
type DIJob struct {
metav1.TypeMeta `json:",inline"`
metav1.ObjectMeta `json:"metadata,omitempty"`
Spec DIJobSpec `json:"spec,omitempty"`
Status DIJobStatus `json:"status,omitempty"`
}
// +kubebuilder:object:root=true
// DIJobList contains a list of DIJob
type DIJobList struct {
metav1.TypeMeta `json:",inline"`
metav1.ListMeta `json:"metadata,omitempty"`
Items []DIJob `json:"items"`
}
func init() {
SchemeBuilder.Register(&DIJob{}, &DIJobList{})
}
/*
Copyright 2021 The OpenDILab authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package v1alpha1
import (
"fmt"
"k8s.io/apimachinery/pkg/runtime"
ctrl "sigs.k8s.io/controller-runtime"
logf "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/webhook"
)
// log is for logging in this package.
var dijoblog = logf.Log.WithName("dijob-resource")
func (r *DIJob) SetupWebhookWithManager(mgr ctrl.Manager) error {
return ctrl.NewWebhookManagedBy(mgr).
For(r).
Complete()
}
// EDIT THIS FILE! THIS IS SCAFFOLDING FOR YOU TO OWN!
//+kubebuilder:webhook:path=/mutate-diengine-opendilab-org-v1alpha1-dijob,mutating=true,failurePolicy=fail,sideEffects=None,groups=diengine.opendilab.org,resources=dijobs,verbs=create;update,versions=v1alpha1,name=mdijob.kb.io,admissionReviewVersions={v1,v1beta1}
var _ webhook.Defaulter = &DIJob{}
// Default implements webhook.Defaulter so a webhook will be registered for the type
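// For example, a DIJob created without spec.cleanPodPolicy is defaulted to CleanPodPolicyRunning.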
func (r *DIJob) Default() {
dijoblog.Info("default", "name", r.Name)
if r.Spec.CleanPodPolicy == "" {
r.Spec.CleanPodPolicy = CleanPodPolicyRunning
}
}
// TODO(user): change verbs to "verbs=create;update;delete" if you want to enable deletion validation.
//+kubebuilder:webhook:path=/validate-diengine-opendilab-org-v1alpha1-dijob,mutating=false,failurePolicy=fail,sideEffects=None,groups=diengine.opendilab.org,resources=dijobs,verbs=create;update,versions=v1alpha1,name=vdijob.kb.io,admissionReviewVersions={v1,v1beta1}
var _ webhook.Validator = &DIJob{}
// ValidateCreate implements webhook.Validator so a webhook will be registered for the type
func (r *DIJob) ValidateCreate() error {
dijoblog.Info("validate create", "name", r.Name)
// TODO(user): fill in your validation logic upon object creation.
if r.Spec.CleanPodPolicy != CleanPodPolicyALL && r.Spec.CleanPodPolicy != CleanPodPolicyNone &&
r.Spec.CleanPodPolicy != CleanPodPolicyRunning {
return fmt.Errorf("Invalid CleanPodPolicy %s, expected in [%s, %s, %s]",
r.Spec.CleanPodPolicy, CleanPodPolicyNone, CleanPodPolicyRunning, CleanPodPolicyALL)
}
return nil
}
// ValidateUpdate implements webhook.Validator so a webhook will be registered for the type
func (r *DIJob) ValidateUpdate(old runtime.Object) error {
dijoblog.Info("validate update", "name", r.Name)
// TODO(user): fill in your validation logic upon object update.
if r.Spec.CleanPodPolicy != CleanPodPolicyALL && r.Spec.CleanPodPolicy != CleanPodPolicyNone &&
r.Spec.CleanPodPolicy != CleanPodPolicyRunning {
return fmt.Errorf("Invalid CleanPodPolicy %s, expected in [%s, %s, %s]",
r.Spec.CleanPodPolicy, CleanPodPolicyNone, CleanPodPolicyRunning, CleanPodPolicyALL)
}
return nil
}
// ValidateDelete implements webhook.Validator so a webhook will be registered for the type
func (r *DIJob) ValidateDelete() error {
dijoblog.Info("validate delete", "name", r.Name)
// TODO(user): fill in your validation logic upon object deletion.
return nil
}
/*
Copyright 2021 The OpenDILab authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package v1alpha1 contains API Schema definitions for the v1alpha1 API group
//+kubebuilder:object:generate=true
//+groupName=diengine.opendilab.org
package v1alpha1
import (
"k8s.io/apimachinery/pkg/runtime/schema"
"sigs.k8s.io/controller-runtime/pkg/scheme"
)
var (
// KindDIJob is kind of DIJob
KindDIJob = "DIJob"
// KindAGConfig is kind of AGConfig
KindAGConfig = "AggregatorConfig"
// GroupVersion is group version used to register these objects
GroupVersion = schema.GroupVersion{Group: "diengine.opendilab.org", Version: "v1alpha1"}
// SchemeBuilder is used to add go types to the GroupVersionKind scheme
SchemeBuilder = &scheme.Builder{GroupVersion: GroupVersion}
// AddToScheme adds the types in this group-version to the given scheme.
AddToScheme = SchemeBuilder.AddToScheme
)
/*
Copyright 2021 The OpenDILab authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package v1alpha1
import (
"context"
"crypto/tls"
"fmt"
"net"
"path/filepath"
"testing"
"time"
. "github.com/onsi/ginkgo"
"github.com/onsi/ginkgo/config"
. "github.com/onsi/gomega"
admissionv1beta1 "k8s.io/api/admission/v1beta1"
//+kubebuilder:scaffold:imports
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/rest"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/envtest"
"sigs.k8s.io/controller-runtime/pkg/envtest/printer"
logf "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
)
// These tests use Ginkgo (BDD-style Go testing framework). Refer to
// http://onsi.github.io/ginkgo/ to learn more about Ginkgo.
var cfg *rest.Config
var k8sClient client.Client
var testEnv *envtest.Environment
var ctx context.Context
var cancel context.CancelFunc
func TestAPIs(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecsWithDefaultAndCustomReporters(t,
"Webhook Suite",
[]Reporter{printer.NewlineReporter{}})
}
var _ = BeforeSuite(func() {
logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true)))
ctx, cancel = context.WithCancel(context.TODO())
By("bootstrapping test environment")
testEnv = &envtest.Environment{
CRDDirectoryPaths: []string{filepath.Join("..", "..", "config", "crd", "bases")},
ErrorIfCRDPathMissing: false,
WebhookInstallOptions: envtest.WebhookInstallOptions{
Paths: []string{filepath.Join("..", "..", "config", "webhook")},
LocalServingPort: 8100 + config.GinkgoConfig.ParallelNode,
},
}
cfg, err := testEnv.Start()
Expect(err).NotTo(HaveOccurred())
Expect(cfg).NotTo(BeNil())
scheme := runtime.NewScheme()
err = AddToScheme(scheme)
Expect(err).NotTo(HaveOccurred())
err = admissionv1beta1.AddToScheme(scheme)
Expect(err).NotTo(HaveOccurred())
//+kubebuilder:scaffold:scheme
k8sClient, err = client.New(cfg, client.Options{Scheme: scheme})
Expect(err).NotTo(HaveOccurred())
Expect(k8sClient).NotTo(BeNil())
// start webhook server using Manager
webhookInstallOptions := &testEnv.WebhookInstallOptions
mgr, err := ctrl.NewManager(cfg, ctrl.Options{
Scheme: scheme,
Host: webhookInstallOptions.LocalServingHost,
Port: webhookInstallOptions.LocalServingPort,
CertDir: webhookInstallOptions.LocalServingCertDir,
LeaderElection: false,
MetricsBindAddress: "0",
})
Expect(err).NotTo(HaveOccurred())
err = (&DIJob{}).SetupWebhookWithManager(mgr)
Expect(err).NotTo(HaveOccurred())
//+kubebuilder:scaffold:webhook
go func() {
err = mgr.Start(ctx)
if err != nil {
Expect(err).NotTo(HaveOccurred())
}
}()
// wait for the webhook server to get ready
dialer := &net.Dialer{Timeout: time.Second}
addrPort := fmt.Sprintf("%s:%d", webhookInstallOptions.LocalServingHost, webhookInstallOptions.LocalServingPort)
Eventually(func() error {
conn, err := tls.DialWithDialer(dialer, "tcp", addrPort, &tls.Config{InsecureSkipVerify: true})
if err != nil {
return err
}
conn.Close()
return nil
}).Should(Succeed())
}, 60)
var _ = AfterSuite(func() {
cancel()
By("tearing down the test environment")
err := testEnv.Stop()
Expect(err).NotTo(HaveOccurred())
})
package v1alpha1
import (
"context"
"fmt"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
utilrand "k8s.io/apimachinery/pkg/util/rand"
"sigs.k8s.io/controller-runtime/pkg/client"
)
var _ = Describe("Webhook test", func() {
Context("When creating a DIJob", func() {
It("Should be validated by webhook before creating", func() {
type testCase struct {
cleanPodPolicy CleanPodPolicy
expectCleanPodPolicy CleanPodPolicy
}
testCases := []testCase{
{cleanPodPolicy: CleanPodPolicyRunning, expectCleanPodPolicy: CleanPodPolicyRunning},
{cleanPodPolicy: CleanPodPolicyALL, expectCleanPodPolicy: CleanPodPolicyALL},
{cleanPodPolicy: CleanPodPolicyNone, expectCleanPodPolicy: CleanPodPolicyNone},
{cleanPodPolicy: CleanPodPolicy(""), expectCleanPodPolicy: CleanPodPolicyRunning},
{cleanPodPolicy: CleanPodPolicy("hello"), expectCleanPodPolicy: CleanPodPolicy("will be refused by webhook")},
{cleanPodPolicy: CleanPodPolicy("sdft"), expectCleanPodPolicy: CleanPodPolicy("will be refused by webhook")},
}
for i := range testCases {
c := testCases[i]
job := NewDIJob()
name := GenerateName(job.Name)
job.SetName(name)
job.Spec.CleanPodPolicy = c.cleanPodPolicy
var err error
ctx := context.Background()
err = k8sClient.Create(ctx, job, &client.CreateOptions{})
if err != nil {
if c.cleanPodPolicy != CleanPodPolicyRunning && c.cleanPodPolicy != CleanPodPolicyNone &&
c.cleanPodPolicy != CleanPodPolicyALL {
Expect(err.Error()).To(ContainSubstring("Invalid CleanPodPolicy"))
continue
} else {
Expect(err).NotTo(HaveOccurred())
}
}
cjob := DIJob{}
jobKey := types.NamespacedName{Namespace: job.Namespace, Name: job.Name}
Eventually(func() bool {
err = k8sClient.Get(ctx, jobKey, &cjob)
if err != nil {
return false
}
return cjob.Spec.CleanPodPolicy == c.expectCleanPodPolicy
}, timeout, interval).Should(BeTrue())
}
})
It("Should be validated by webhook before updating", func() {
type testCase struct {
cleanPodPolicy CleanPodPolicy
expectCleanPodPolicy CleanPodPolicy
}
testCases := []testCase{
{cleanPodPolicy: CleanPodPolicyRunning, expectCleanPodPolicy: CleanPodPolicyRunning},
{cleanPodPolicy: CleanPodPolicyALL, expectCleanPodPolicy: CleanPodPolicyALL},
{cleanPodPolicy: CleanPodPolicyNone, expectCleanPodPolicy: CleanPodPolicyNone},
{cleanPodPolicy: CleanPodPolicy(""), expectCleanPodPolicy: CleanPodPolicyRunning},
{cleanPodPolicy: CleanPodPolicy("hello"), expectCleanPodPolicy: CleanPodPolicy("will be refused by webhook")},
{cleanPodPolicy: CleanPodPolicy("sdft"), expectCleanPodPolicy: CleanPodPolicy("will be refused by webhook")},
}
for i := range testCases {
c := testCases[i]
job := NewDIJob()
name := GenerateName(job.Name)
job.SetName(name)
var err error
ctx := context.Background()
err = k8sClient.Create(ctx, job, &client.CreateOptions{})
Expect(err).NotTo(HaveOccurred())
job.Spec.CleanPodPolicy = c.cleanPodPolicy
err = k8sClient.Update(ctx, job, &client.UpdateOptions{})
if err != nil {
if c.cleanPodPolicy != CleanPodPolicyRunning && c.cleanPodPolicy != CleanPodPolicyNone &&
c.cleanPodPolicy != CleanPodPolicyALL {
Expect(err.Error()).To(ContainSubstring("Invalid CleanPodPolicy"))
continue
} else {
Expect(err).NotTo(HaveOccurred())
}
}
cjob := DIJob{}
jobKey := types.NamespacedName{Namespace: job.Namespace, Name: job.Name}
Eventually(func() CleanPodPolicy {
err = k8sClient.Get(ctx, jobKey, &cjob)
if err != nil {
return CleanPodPolicy(err.Error())
}
return cjob.Spec.CleanPodPolicy
}, timeout, interval).Should(Equal(c.expectCleanPodPolicy))
}
})
})
})
const (
randomLength = 5
DIJobName = "dijob-example"
DIJobNamespace = "default"
DIJobImage = "alpine:latest"
DefaultSleepDuration = "5s"
timeout = 5 * time.Second
interval = 250 * time.Millisecond
)
func NewDIJob() *DIJob {
return &DIJob{
TypeMeta: metav1.TypeMeta{
Kind: KindDIJob,
APIVersion: GroupVersion.String(),
},
ObjectMeta: metav1.ObjectMeta{
Name: DIJobName,
Namespace: DIJobNamespace,
},
Spec: DIJobSpec{
Coordinator: CoordinatorSpec{
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "coordinator",
Image: DIJobImage,
Command: []string{"/bin/sh", "-c", "sleep " + DefaultSleepDuration},
},
},
},
},
},
Collector: CollectorSpec{
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "collector",
Image: DIJobImage,
Command: []string{"/bin/sh", "-c", "sleep " + DefaultSleepDuration},
},
},
},
},
},
Learner: LearnerSpec{
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "learner",
Image: DIJobImage,
Command: []string{"/bin/sh", "-c", "sleep " + DefaultSleepDuration},
},
},
},
},
},
},
}
}
func GenerateName(name string) string {
return fmt.Sprintf("%s-%s", name, utilrand.String(randomLength))
}
// +build !ignore_autogenerated
/*
Copyright 2021 The OpenDILab authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Code generated by controller-gen. DO NOT EDIT.
package v1alpha1
import (
"k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime"
)
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *AggregatorConfig) DeepCopyInto(out *AggregatorConfig) {
*out = *in
out.TypeMeta = in.TypeMeta
in.ObjectMeta.DeepCopyInto(&out.ObjectMeta)
in.Spec.DeepCopyInto(&out.Spec)
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new AggregatorConfig.
func (in *AggregatorConfig) DeepCopy() *AggregatorConfig {
if in == nil {
return nil
}
out := new(AggregatorConfig)
in.DeepCopyInto(out)
return out
}
// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object.
func (in *AggregatorConfig) DeepCopyObject() runtime.Object {
if c := in.DeepCopy(); c != nil {
return c
}
return nil
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *AggregatorConfigList) DeepCopyInto(out *AggregatorConfigList) {
*out = *in
out.TypeMeta = in.TypeMeta
in.ListMeta.DeepCopyInto(&out.ListMeta)
if in.Items != nil {
in, out := &in.Items, &out.Items
*out = make([]AggregatorConfig, len(*in))
for i := range *in {
(*in)[i].DeepCopyInto(&(*out)[i])
}
}
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new AggregatorConfigList.
func (in *AggregatorConfigList) DeepCopy() *AggregatorConfigList {
if in == nil {
return nil
}
out := new(AggregatorConfigList)
in.DeepCopyInto(out)
return out
}
// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object.
func (in *AggregatorConfigList) DeepCopyObject() runtime.Object {
if c := in.DeepCopy(); c != nil {
return c
}
return nil
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *AggregatorConfigSpec) DeepCopyInto(out *AggregatorConfigSpec) {
*out = *in
in.Aggregator.DeepCopyInto(&out.Aggregator)
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new AggregatorConfigSpec.
func (in *AggregatorConfigSpec) DeepCopy() *AggregatorConfigSpec {
if in == nil {
return nil
}
out := new(AggregatorConfigSpec)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *AggregatorConfigStatus) DeepCopyInto(out *AggregatorConfigStatus) {
*out = *in
if in.Actors != nil {
in, out := &in.Actors, &out.Actors
*out = new(AggregatorReplicaStatus)
**out = **in
}
if in.Learners != nil {
in, out := &in.Learners, &out.Learners
*out = new(AggregatorReplicaStatus)
**out = **in
}
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new AggregatorConfigStatus.
func (in *AggregatorConfigStatus) DeepCopy() *AggregatorConfigStatus {
if in == nil {
return nil
}
out := new(AggregatorConfigStatus)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *AggregatorReplicaStatus) DeepCopyInto(out *AggregatorReplicaStatus) {
*out = *in
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new AggregatorReplicaStatus.
func (in *AggregatorReplicaStatus) DeepCopy() *AggregatorReplicaStatus {
if in == nil {
return nil
}
out := new(AggregatorReplicaStatus)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *AggregatorSpec) DeepCopyInto(out *AggregatorSpec) {
*out = *in
in.Template.DeepCopyInto(&out.Template)
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new AggregatorSpec.
func (in *AggregatorSpec) DeepCopy() *AggregatorSpec {
if in == nil {
return nil
}
out := new(AggregatorSpec)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *CollectorSpec) DeepCopyInto(out *CollectorSpec) {
*out = *in
in.Template.DeepCopyInto(&out.Template)
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new CollectorSpec.
func (in *CollectorSpec) DeepCopy() *CollectorSpec {
if in == nil {
return nil
}
out := new(CollectorSpec)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *CoordinatorSpec) DeepCopyInto(out *CoordinatorSpec) {
*out = *in
in.Template.DeepCopyInto(&out.Template)
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new CoordinatorSpec.
func (in *CoordinatorSpec) DeepCopy() *CoordinatorSpec {
if in == nil {
return nil
}
out := new(CoordinatorSpec)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *DIJob) DeepCopyInto(out *DIJob) {
*out = *in
out.TypeMeta = in.TypeMeta
in.ObjectMeta.DeepCopyInto(&out.ObjectMeta)
in.Spec.DeepCopyInto(&out.Spec)
in.Status.DeepCopyInto(&out.Status)
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DIJob.
func (in *DIJob) DeepCopy() *DIJob {
if in == nil {
return nil
}
out := new(DIJob)
in.DeepCopyInto(out)
return out
}
// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object.
func (in *DIJob) DeepCopyObject() runtime.Object {
if c := in.DeepCopy(); c != nil {
return c
}
return nil
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *DIJobCondition) DeepCopyInto(out *DIJobCondition) {
*out = *in
in.LastUpdateTime.DeepCopyInto(&out.LastUpdateTime)
in.LastTransitionTime.DeepCopyInto(&out.LastTransitionTime)
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DIJobCondition.
func (in *DIJobCondition) DeepCopy() *DIJobCondition {
if in == nil {
return nil
}
out := new(DIJobCondition)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *DIJobList) DeepCopyInto(out *DIJobList) {
*out = *in
out.TypeMeta = in.TypeMeta
in.ListMeta.DeepCopyInto(&out.ListMeta)
if in.Items != nil {
in, out := &in.Items, &out.Items
*out = make([]DIJob, len(*in))
for i := range *in {
(*in)[i].DeepCopyInto(&(*out)[i])
}
}
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DIJobList.
func (in *DIJobList) DeepCopy() *DIJobList {
if in == nil {
return nil
}
out := new(DIJobList)
in.DeepCopyInto(out)
return out
}
// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object.
func (in *DIJobList) DeepCopyObject() runtime.Object {
if c := in.DeepCopy(); c != nil {
return c
}
return nil
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *DIJobSpec) DeepCopyInto(out *DIJobSpec) {
*out = *in
if in.Volumes != nil {
in, out := &in.Volumes, &out.Volumes
*out = make([]v1.Volume, len(*in))
for i := range *in {
(*in)[i].DeepCopyInto(&(*out)[i])
}
}
in.Coordinator.DeepCopyInto(&out.Coordinator)
in.Collector.DeepCopyInto(&out.Collector)
in.Learner.DeepCopyInto(&out.Learner)
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DIJobSpec.
func (in *DIJobSpec) DeepCopy() *DIJobSpec {
if in == nil {
return nil
}
out := new(DIJobSpec)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *DIJobStatus) DeepCopyInto(out *DIJobStatus) {
*out = *in
if in.Conditions != nil {
in, out := &in.Conditions, &out.Conditions
*out = make([]DIJobCondition, len(*in))
for i := range *in {
(*in)[i].DeepCopyInto(&(*out)[i])
}
}
if in.ReplicaStatus != nil {
in, out := &in.ReplicaStatus, &out.ReplicaStatus
*out = make(map[ReplicaType]*ReplicaStatus, len(*in))
for key, val := range *in {
var outVal *ReplicaStatus
if val == nil {
(*out)[key] = nil
} else {
in, out := &val, &outVal
*out = new(ReplicaStatus)
**out = **in
}
(*out)[key] = outVal
}
}
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DIJobStatus.
func (in *DIJobStatus) DeepCopy() *DIJobStatus {
if in == nil {
return nil
}
out := new(DIJobStatus)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *LearnerSpec) DeepCopyInto(out *LearnerSpec) {
*out = *in
in.Template.DeepCopyInto(&out.Template)
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new LearnerSpec.
func (in *LearnerSpec) DeepCopy() *LearnerSpec {
if in == nil {
return nil
}
out := new(LearnerSpec)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *ReplicaStatus) DeepCopyInto(out *ReplicaStatus) {
*out = *in
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ReplicaStatus.
func (in *ReplicaStatus) DeepCopy() *ReplicaStatus {
if in == nil {
return nil
}
out := new(ReplicaStatus)
in.DeepCopyInto(out)
return out
}
package common
import (
"fmt"
corev1 "k8s.io/api/core/v1"
)
type GPUAllocator struct {
Nodes []*corev1.Node
policy Policy
}
func NewSimpleGPUAllocator(nodes []*corev1.Node) *GPUAllocator {
return &GPUAllocator{Nodes: nodes, policy: NewSimplePolicy()}
}
func (g *GPUAllocator) Allocate(gpus int) []int {
return g.policy.Allocate(g.Nodes, gpus)
}
func (g *GPUAllocator) NumGPUsOfMajorityNodeType() int {
return GetGPUsMajority(g.Nodes)
}
const (
SimpleGPUAllocPolicy = "simple"
)
type Policy interface {
Allocate(nodes []*corev1.Node, gpus int) []int
}
type SimplePolicy struct{}
func NewSimplePolicy() *SimplePolicy {
return &SimplePolicy{}
}
func (s *SimplePolicy) Allocate(nodes []*corev1.Node, gpus int) []int {
// gpusMajority is the GPU count that occurs most frequently among the nodes;
// each allocation chunk handed out below contains at most that many GPUs.
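// For example, with nodes offering {8 GPUs x 4 nodes, 10 GPUs x 3, 6 GPUs x 3}, gpusMajority is 8,
// so a request for 12 GPUs is split into [8, 4] and a request for 16 GPUs into [8, 8].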
gpusMajority := GetGPUsMajority(nodes)
if gpusMajority <= 0 {
return nil
}
perNodeGPUs := Max(gpusMajority, 1)
if gpus < perNodeGPUs {
return []int{gpus}
}
var result []int
nResults := gpus / perNodeGPUs
for i := 0; i < nResults; i++ {
result = append(result, perNodeGPUs)
}
remainGPUs := gpus - nResults*perNodeGPUs
if remainGPUs > 0 {
result = append(result, remainGPUs)
}
return result
}
func Max(x, y int) int {
if x < y {
return y
}
return x
}
func MaxInArray(v []int) (int, error) {
if len(v) == 0 {
return 0, fmt.Errorf("empty list")
}
max := v[0]
for _, i := range v {
if i > max {
max = i
}
}
return max, nil
}
func GetGPUsMajority(nodes []*corev1.Node) int {
var nodeGPUCounts []int
for _, node := range nodes {
allocGPUs := node.Status.Allocatable[corev1.ResourceName("nvidia.com/gpu")]
nodeGPUCounts = append(nodeGPUCounts, int(allocGPUs.Value()))
}
// gpusMajority is the GPU count shared by the largest group of nodes;
// if it is 0, fall back to the maximum GPU count found on any node.
gpusMajority, _ := ValueOccursMostFrequentInList(nodeGPUCounts)
if gpusMajority == 0 {
max, _ := MaxInArray(nodeGPUCounts)
return max
}
return gpusMajority
}
// ValueOccursMostFrequentInList returns the value that occurs most frequently in list,
// together with its number of occurrences; ties are broken in favor of the larger value.
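// For example, ValueOccursMostFrequentInList([]int{1, 2, 3, 4, 5, 2, 2, 4, 6, 2, 3, 3, 1}) returns (2, 4).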
func ValueOccursMostFrequentInList(list []int) (int, int) {
if len(list) == 0 {
return 0, 0
}
// map the occurrence frequency of each value
maxCount := 0
maxCountValue := 0
valuesMap := make(map[int]int)
for _, v := range list {
if valuesMap[v] != 0 {
valuesMap[v]++
} else {
valuesMap[v] = 1
}
if maxCount < valuesMap[v] {
maxCount = valuesMap[v]
maxCountValue = v
} else if maxCount == valuesMap[v] && maxCountValue < v {
maxCountValue = v
}
}
return maxCountValue, maxCount
}
package common
import (
"testing"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
"sigs.k8s.io/controller-runtime/pkg/envtest/printer"
)
func TestAllocators(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecsWithDefaultAndCustomReporters(t,
"Allocator Suite",
[]Reporter{printer.NewlineReporter{}})
}
var _ = Describe("Test SimpleGPUAllocator", func() {
It("ValueOccursMostFrequentInList function", func() {
testCases := map[string]struct {
list []int
expectedValue int
expectedCount int
}{
"Only one max value": {
[]int{1, 2, 3, 4, 5, 2, 2, 4, 6, 2, 3, 3, 1},
2, 4,
},
"Multi max value": {
[]int{1, 2, 3, 4, 5, 2, 2, 4, 6, 2, 3, 3, 1, 3},
3, 4,
},
"Multi max value second": {
[]int{1, 3, 3, 4, 5, 2, 2, 4, 6, 2, 3, 2, 1, 3},
3, 4,
},
}
for _, test := range testCases {
maxValue, maxCount := ValueOccursMostFrequentInList(test.list)
Expect(maxValue).To(Equal(test.expectedValue))
Expect(maxCount).To(Equal(test.expectedCount))
}
})
It("Allocate function", func() {
testCases := map[string]struct {
nodeGPUs map[int]int
gpus int
result []int
}{
"Only one max value with 12 gpus request": {
map[int]int{
8: 4,
10: 3,
6: 3,
},
12, []int{8, 4},
},
"Only one max value with 16 gpus request": {
map[int]int{
8: 4,
10: 3,
6: 3,
},
16, []int{8, 8},
},
"Multi max value with 16 gpus request": {
map[int]int{
8: 4,
10: 4,
6: 3,
},
16, []int{10, 6},
},
"Multi max value with 8 gpus request": {
map[int]int{
8: 4,
10: 4,
6: 3,
},
8, []int{8},
},
}
for _, test := range testCases {
var nodes []*corev1.Node
for nodeSpec, nodeGPUs := range test.nodeGPUs {
for i := 0; i < nodeGPUs; i++ {
nodes = append(nodes, newNode(nodeSpec))
}
}
alloc := NewSimpleGPUAllocator(nodes)
result := alloc.Allocate(test.gpus)
Expect(result).To(Equal(test.result))
}
})
})
func newNode(gpus int) *corev1.Node {
return &corev1.Node{
Status: corev1.NodeStatus{
Allocatable: corev1.ResourceList{
"nvidia.com/gpu": *resource.NewQuantity(int64(gpus), resource.DecimalExponent),
},
},
}
}
This diff is collapsed.
# The following manifests contain a self-signed issuer CR and a certificate CR.
# More document can be found at https://docs.cert-manager.io
# WARNING: Targets CertManager v1.0. Check https://cert-manager.io/docs/installation/upgrading/ for breaking changes.
apiVersion: cert-manager.io/v1
kind: Issuer
metadata:
name: di-selfsigned-issuer
namespace: di-system
spec:
selfSigned: {}
---
apiVersion: cert-manager.io/v1
kind: Certificate
metadata:
name: di-serving-cert # this name should match the one that appears in kustomizeconfig.yaml
namespace: di-system
spec:
# $(SERVICE_NAME) and $(SERVICE_NAMESPACE) will be substituted by kustomize
dnsNames:
- $(SERVICE_NAME).$(SERVICE_NAMESPACE).svc
- $(SERVICE_NAME).$(SERVICE_NAMESPACE).svc.cluster.local
issuerRef:
kind: Issuer
name: di-selfsigned-issuer
secretName: di-webhook-server-cert # this secret will not be prefixed, since it's not managed by kustomize
resources:
- certificate.yaml
configurations:
- kustomizeconfig.yaml
# This configuration is for teaching kustomize how to update name ref and var substitution
nameReference:
- kind: Issuer
group: cert-manager.io
fieldSpecs:
- kind: Certificate
group: cert-manager.io
path: spec/issuerRef/name
varReference:
- kind: Certificate
group: cert-manager.io
path: spec/commonName
- kind: Certificate
group: cert-manager.io
path: spec/dnsNames
This diff is collapsed.
# This kustomization.yaml is not intended to be run by itself,
# since it depends on service name and namespace that are out of this kustomize package.
# It should be run by config/default
resources:
- bases/diengine.opendilab.org_dijobs.yaml
- bases/diengine.opendilab.org_aggregatorconfigs.yaml
#+kubebuilder:scaffold:crdkustomizeresource
patchesStrategicMerge:
# [WEBHOOK] To enable webhook, uncomment all the sections with [WEBHOOK] prefix.
# patches here are for enabling the conversion webhook for each CRD
- patches/webhook_in_dijobs.yaml
- patches/webhook_in_aggregatorconfigs.yaml
#+kubebuilder:scaffold:crdkustomizewebhookpatch
# [CERTMANAGER] To enable webhook, uncomment all the sections with [CERTMANAGER] prefix.
# patches here are for enabling the CA injection for each CRD
- patches/cainjection_in_dijobs.yaml
- patches/cainjection_in_aggregatorconfigs.yaml
#+kubebuilder:scaffold:crdkustomizecainjectionpatch
# the following config is for teaching kustomize how to do kustomization for CRDs.
configurations:
- kustomizeconfig.yaml
# This file is for teaching kustomize how to substitute name and namespace reference in CRD
nameReference:
- kind: Service
version: v1
fieldSpecs:
- kind: CustomResourceDefinition
version: v1
group: apiextensions.k8s.io
path: spec/conversion/webhook/clientConfig/service/name
namespace:
- kind: CustomResourceDefinition
version: v1
group: apiextensions.k8s.io
path: spec/conversion/webhook/clientConfig/service/namespace
create: false
varReference:
- path: metadata/annotations
- kind: CustomResourceDefinition
group: apiextensions.k8s.io
path: spec/conversion/webhook/clientConfig/service/name
# The following patch adds a directive for certmanager to inject CA into the CRD
apiVersion: apiextensions.k8s.io/v1
kind: CustomResourceDefinition
metadata:
annotations:
cert-manager.io/inject-ca-from: $(CERTIFICATE_NAMESPACE)/$(CERTIFICATE_NAME)
name: aggregatorconfigs.diengine.opendilab.org
# The following patch adds a directive for certmanager to inject CA into the CRD
apiVersion: apiextensions.k8s.io/v1
kind: CustomResourceDefinition
metadata:
annotations:
cert-manager.io/inject-ca-from: $(CERTIFICATE_NAMESPACE)/$(CERTIFICATE_NAME)
name: dijobs.diengine.opendilab.org
# The following patch enables a conversion webhook for the CRD
apiVersion: apiextensions.k8s.io/v1
kind: CustomResourceDefinition
metadata:
name: aggregatorconfigs.diengine.opendilab.org
spec:
conversion:
strategy: Webhook
webhook:
clientConfig:
service:
namespace: $(SERVICE_NAMESPACE)
name: $(SERVICE_NAME)
path: /mutate-diengine-opendilab-org-v1alpha1-dijob
conversionReviewVersions:
- "v1"
- "v1beta1"
# The following patch enables a conversion webhook for the CRD
apiVersion: apiextensions.k8s.io/v1
kind: CustomResourceDefinition
metadata:
name: dijobs.diengine.opendilab.org
spec:
conversion:
strategy: Webhook
webhook:
clientConfig:
service:
namespace: $(SERVICE_NAMESPACE)
name: $(SERVICE_NAME)
path: /mutate-diengine-opendilab-org-v1alpha1-dijob
conversionReviewVersions:
- "v1"
- "v1beta1"
# Adds namespace to all resources.
namespace: di-system
# Value of this field is prepended to the
# names of all resources, e.g. a deployment named
# "wordpress" becomes "alices-wordpress".
# Note that it should also match with the prefix (text before '-') of the namespace
# field above.
# namePrefix: di-
# Labels to add to all resources and selectors.
#commonLabels:
# someName: someValue
bases:
- ../crd
- ../rbac
- ../manager
# [WEBHOOK] To enable webhook, uncomment all the sections with [WEBHOOK] prefix including the one in
# crd/kustomization.yaml
- ../webhook
# [CERTMANAGER] To enable cert-manager, uncomment all sections with 'CERTMANAGER'. 'WEBHOOK' components are required.
- ../certmanager
# [PROMETHEUS] To enable prometheus monitor, uncomment all sections with 'PROMETHEUS'.
#- ../prometheus
patchesStrategicMerge:
# Protect the /metrics endpoint by putting it behind auth.
# If you want your controller-manager to expose the /metrics
# endpoint w/o any authn/z, please comment the following line.
- manager_auth_proxy_patch.yaml
# Mount the controller config file for loading manager configurations
# through a ComponentConfig type
#- manager_config_patch.yaml
# [WEBHOOK] To enable webhook, uncomment all the sections with [WEBHOOK] prefix including the one in
# crd/kustomization.yaml
- manager_webhook_patch.yaml
# [CERTMANAGER] To enable cert-manager, uncomment all sections with 'CERTMANAGER'.
# Uncomment 'CERTMANAGER' sections in crd/kustomization.yaml to enable the CA injection in the admission webhooks.
# 'CERTMANAGER' needs to be enabled to use ca injection
- webhookcainjection_patch.yaml
vars:
# [CERTMANAGER] To enable cert-manager, uncomment all sections with 'CERTMANAGER' prefix.
- name: CERTIFICATE_NAMESPACE # namespace of the certificate CR
objref:
kind: Certificate
group: cert-manager.io
version: v1
name: di-serving-cert # this name should match the one in certificate.yaml
fieldref:
fieldpath: metadata.namespace
- name: CERTIFICATE_NAME
objref:
kind: Certificate
group: cert-manager.io
version: v1
name: di-serving-cert # this name should match the one in certificate.yaml
- name: SERVICE_NAMESPACE # namespace of the service
objref:
kind: Service
version: v1
name: di-webhook-service
fieldref:
fieldpath: metadata.namespace
- name: SERVICE_NAME
objref:
kind: Service
version: v1
name: di-webhook-service
\ No newline at end of file
# This patch injects a sidecar container which is an HTTP proxy for the
# controller manager; it performs RBAC authorization against the Kubernetes API using SubjectAccessReviews.
apiVersion: apps/v1
kind: Deployment
metadata:
name: di-operator
namespace: di-system
spec:
template:
spec:
containers:
- name: manager
args:
- "--health-probe-bind-address=:8081"
- "--metrics-bind-address=:8080"
- "--leader-elect"
apiVersion: apps/v1
kind: Deployment
metadata:
name: di-operator
namespace: di-system
spec:
template:
spec:
containers:
- name: manager
args:
- "--config=controller_manager_config.yaml"
volumeMounts:
- name: manager-config
mountPath: /controller_manager_config.yaml
subPath: controller_manager_config.yaml
volumes:
- name: manager-config
configMap:
name: manager-config
apiVersion: apps/v1
kind: Deployment
metadata:
name: di-operator
namespace: di-system
spec:
template:
spec:
containers:
- name: manager
ports:
- containerPort: 9443
name: webhook-server
protocol: TCP
volumeMounts:
- mountPath: /tmp/k8s-webhook-server/serving-certs
name: cert
readOnly: true
volumes:
- name: cert
secret:
defaultMode: 420
secretName: di-webhook-server-cert
# This patch adds an annotation to the admission webhook config and
# the variables $(CERTIFICATE_NAMESPACE) and $(CERTIFICATE_NAME) will be substituted by kustomize.
apiVersion: admissionregistration.k8s.io/v1
kind: MutatingWebhookConfiguration
metadata:
name: di-mutating-webhook-configuration
annotations:
cert-manager.io/inject-ca-from: $(CERTIFICATE_NAMESPACE)/$(CERTIFICATE_NAME)
---
apiVersion: admissionregistration.k8s.io/v1
kind: ValidatingWebhookConfiguration
metadata:
name: di-validating-webhook-configuration
annotations:
cert-manager.io/inject-ca-from: $(CERTIFICATE_NAMESPACE)/$(CERTIFICATE_NAME)
This diff has been collapsed.
apiVersion: v1
kind: Namespace
metadata:
labels:
control-plane: di-operator
name: di-system
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: di-operator
namespace: di-system
labels:
control-plane: di-operator
spec:
selector:
matchLabels:
control-plane: di-operator
replicas: 1
template:
metadata:
labels:
control-plane: di-operator
spec:
containers:
- command:
- /di-operator
args:
- "--server-address=http://di-server.di-system:8080"
image: diorchestrator/di-operator:v0.1.0
imagePullPolicy: Always
name: manager
securityContext:
allowPrivilegeEscalation: false
livenessProbe:
httpGet:
path: /healthz
port: 8081
initialDelaySeconds: 15
periodSeconds: 20
readinessProbe:
httpGet:
path: /readyz
port: 8081
initialDelaySeconds: 5
periodSeconds: 10
resources:
limits:
cpu: 100m
memory: 500Mi
requests:
cpu: 100m
memory: 500Mi
terminationGracePeriodSeconds: 10
apiVersion: apps/v1
kind: Deployment
metadata:
name: di-server
namespace: di-system
labels:
control-plane: di-server
spec:
selector:
matchLabels:
control-plane: di-server
replicas: 1
template:
metadata:
labels:
control-plane: di-server
spec:
containers:
- command:
- /di-server
args:
- "--server-bind-address=:8080"
- "--leader-elect"
- "--lease-lock-namespace=di-system"
- "--lease-lock-name=di-server"
image: diorchestrator/di-server:v0.1.0
imagePullPolicy: Always
name: server
securityContext:
allowPrivilegeEscalation: false
livenessProbe:
httpGet:
path: /healthz
port: 8080
initialDelaySeconds: 15
periodSeconds: 20
resources:
limits:
cpu: 100m
memory: 500Mi
requests:
cpu: 100m
memory: 500Mi
terminationGracePeriodSeconds: 10
resources:
- di_operator.yaml
- di_server.yaml
apiVersion: kustomize.config.k8s.io/v1beta1
kind: Kustomization
images:
- name: diorchestrator/di-operator
newName: diorchestrator/di-operator
newTag: v0.1.0
- name: diorchestrator/di-server
newName: diorchestrator/di-server
newTag: v0.1.0
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRole
metadata:
name: di-metrics-reader
rules:
- nonResourceURLs: ["/metrics"]
verbs: ["get"]
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRole
metadata:
name: di-proxy-role
rules:
- apiGroups: ["authentication.k8s.io"]
resources:
- tokenreviews
verbs: ["create"]
- apiGroups: ["authorization.k8s.io"]
resources:
- subjectaccessreviews
verbs: ["create"]
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
name: di-proxy-rolebinding
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: ClusterRole
name: di-proxy-role
subjects:
- kind: ServiceAccount
name: default
namespace: di-system
apiVersion: v1
kind: Service
metadata:
labels:
control-plane: di-operator
name: di-operator-metrics-service
namespace: di-system
spec:
ports:
- name: https
port: 8443
targetPort: 8080
selector:
control-plane: di-operator
---
apiVersion: v1
kind: Service
metadata:
name: di-server
namespace: di-system
spec:
selector:
control-plane: di-server
ports:
- protocol: TCP
port: 8080
targetPort: 8080
---
apiVersion: v1
kind: Service
metadata:
name: di-server-nodeport
namespace: di-system
spec:
selector:
control-plane: di-server
type: NodePort
ports:
- protocol: TCP
port: 8080
targetPort: 8080
nodePort: 32270
\ No newline at end of file
resources:
- role.yaml
- role_binding.yaml
- leader_election_role.yaml
- leader_election_role_binding.yaml
# Comment the following 4 lines if you want to disable
# the auth proxy (https://github.com/brancz/kube-rbac-proxy)
# which protects your /metrics endpoint.
- auth_proxy_service.yaml
- auth_proxy_role.yaml
- auth_proxy_role_binding.yaml
- auth_proxy_client_clusterrole.yaml
- di_server_service.yaml
\ No newline at end of file
# permissions to do leader election.
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
name: di-leader-election-role
namespace: di-system
rules:
- apiGroups:
- ""
- coordination.k8s.io
resources:
- configmaps
- leases
verbs:
- get
- list
- watch
- create
- update
- patch
- delete
- apiGroups:
- ""
resources:
- events
verbs:
- create
- patch
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
name: di-leader-election-rolebinding
namespace: di-system
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: Role
name: di-leader-election-role
subjects:
- kind: ServiceAccount
name: default
namespace: di-system
---
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRole
metadata:
creationTimestamp: null
name: di-operator-cluster-role
rules:
- apiGroups:
- ""
resources:
- events
- pods
- services
verbs:
- create
- delete
- get
- list
- patch
- update
- watch
- apiGroups:
- ""
resources:
- namespaces
- nodes
verbs:
- get
- list
- apiGroups:
- diengine.opendilab.org
resources:
- aggregatorconfigs
- dijobs
verbs:
- create
- delete
- get
- list
- patch
- update
- watch
- apiGroups:
- diengine.opendilab.org
resources:
- aggregatorconfigs/finalizers
- dijobs/finalizers
verbs:
- update
- apiGroups:
- diengine.opendilab.org
resources:
- aggregatorconfigs/status
- dijobs/status
verbs:
- get
- patch
- update
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
name: di-operator-cluster-rolebinding
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: ClusterRole
name: di-operator-cluster-role
subjects:
- kind: ServiceAccount
name: default
namespace: di-system
apiVersion: diengine.opendilab.org/v1alpha1
kind: AggregatorConfig
metadata:
name: aggregator-config
namespace: di-system
spec:
aggregator:
template:
spec:
containers:
- name: di-container
image: diorchestrator/di-mock:v0.0.5
imagePullPolicy: IfNotPresent
command: ["/bin/bash", "-c",]
args: ["until ping -c 1 $HOSTNAME.default ; do sleep 1 ; done ; sleep 5; python3 -u main.py aggregator -sl $HOSTNAME.default -sp $AGGREGATOR_PORT -sl $HOSTNAME.default -ml $HOSTNAME.default -mp 81"]
ports:
- name: di-port
containerPort: 80
apiVersion: diengine.opendilab.org/v1alpha1
kind: DIJob
metadata:
name: dijob-example
spec:
group: xxx
priorityClassName: ""
cleanPodPolicy: "Running"
coordinator:
template:
spec:
containers:
- name: di-container
image: diorchestrator/di-mock:v0.0.5
imagePullPolicy: Always
command: ["/bin/bash", "-c",]
args: ["python3 -u main.py coordinator -l $HOSTNAME -p $COORDINATOR_PORT"]
# args: ["sleep 3600"]
collector:
template:
spec:
containers:
- name: di-container
image: diorchestrator/di-mock:v0.0.5
imagePullPolicy: Always
command: ["/bin/bash", "-c",]
args: ["until ping -c 1 $HOSTNAME.default ; do sleep 1 ; done ; sleep 10; python3 -u main.py collector -l $HOSTNAME.default -p $COLLECTOR_PORT"]
ports:
- name: di-port
containerPort: 80
learner:
template:
spec:
containers:
- name: di-container
image: diorchestrator/di-mock:v0.0.5
imagePullPolicy: Always
command: ["/bin/bash", "-c",]
args: ["until ping -c 1 $HOSTNAME.default ; do sleep 1 ; done ; sleep 10; python3 -u main.py learner -l $HOSTNAME.default -p $LEARNER_PORT"]
ports:
- name: di-port
containerPort: 80
FROM python:3
WORKDIR /app
COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt
COPY ./data ./data
COPY ./interaction ./interaction
COPY ./utils ./utils
COPY ./worker ./worker
COPY ./main.py ./main.py
# docker build -t diorchestrator/di-mock:v0.0.1 .
\ No newline at end of file
import os
def clear(filepath):
files = os.listdir(filepath)
for fd in files:
cur_path = os.path.join(filepath, fd)
if os.path.isdir(cur_path):
if fd == "__pycache__":
# print("rm %s -rf" % cur_path)
os.system("rm %s -rf" % cur_path)
else:
clear(cur_path)
if __name__ == "__main__":
clear("./")
\ No newline at end of file
from .replay_buffer import ReplayBuffer
\ No newline at end of file
from queue import Queue
class ReplayBuffer:
def __init__(self, cfg: dict):
self._meta_buffer = []
def push_data(self, data) -> None:
self._meta_buffer.append(data)
def sample(self, batch_size: int):
data = None
if len(self._meta_buffer) >= batch_size:
data = self._meta_buffer[:batch_size]
self._meta_buffer = self._meta_buffer[batch_size:]
return data
def run(self):
pass
def close(self):
pass
\ No newline at end of file
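# Illustrative usage sketch (not part of the original sources): exercising the ReplayBuffer
# defined above. The config dict and batch sizes are arbitrary placeholder values.
if __name__ == "__main__":
    buffer = ReplayBuffer(cfg={})
    for i in range(8):
        buffer.push_data({"step": i})
    assert buffer.sample(16) is None   # fewer than 16 items buffered, so nothing is returned
    batch = buffer.sample(4)           # returns (and removes) the 4 oldest items
    print(len(batch), batch[0])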
from .master import *
from .slave import *
from .app import CommonErrorCode, success_response, failure_response, get_values_from_response, flask_response, \
ResponsibleException, responsible
from .common import random_token, translate_dict_func, ControllableService, ControllableContext, default_func
from .network import get_host_ip, get_http_engine_class, HttpEngine, split_http_address
from .threading import DblEvent
import json
from enum import IntEnum, unique
from functools import wraps
from typing import Optional, Any, Mapping, Tuple, Union, Iterable, Type, Callable
import flask
import requests
from flask import jsonify
@unique
class CommonErrorCode(IntEnum):
SUCCESS = 0
COMMON_FAILURE = 1
def flask_response(
success: bool, data: Optional[Mapping[str, Any]] = None, message: Optional[str] = None, code: Optional[int] = None
):
return jsonify(
{
'success': success,
'code': 0 if success else (code or CommonErrorCode.COMMON_FAILURE),
'message': (message or 'Success.') if success else (message or 'Failed.'),
'data': data,
}
)
def success_response(data: Optional[Mapping[str, Any]] = None, message: Optional[str] = None):
return flask_response(
success=True,
code=CommonErrorCode.SUCCESS,
message=message,
data=data,
)
def failure_response(
code: Optional[int] = None, message: Optional[str] = None, data: Optional[Mapping[str, Any]] = None
):
return flask_response(
success=False,
code=code or CommonErrorCode.COMMON_FAILURE,
message=message,
data=data,
)
_RESPONSE_VALUE_TYPE = Tuple[int, bool, int, str, Mapping[str, Any]]
def get_values_from_response(response: Union[requests.Response, flask.Response]) -> _RESPONSE_VALUE_TYPE:
status_code = response.status_code
_content = response.content if hasattr(response, 'content') else response.data
_json = json.loads(_content.decode())
success, code, message, data = _json['success'], _json['code'], _json.get('message', ''), _json.get('data', {})
return status_code, success, code, message, data
class ResponsibleException(Exception):
def __init__(
self,
code: int = CommonErrorCode.COMMON_FAILURE,
message: Optional[str] = None,
data: Optional[Mapping[str, Any]] = None,
status_code: int = 400
):
Exception.__init__(self, message)
self.__code = code
self.__message = message
self.__data = data or {}
self.__status_code = status_code
def get_response(self):
return failure_response(self.__code, self.__message, self.__data), self.__status_code
def responsible(classes: Iterable[Type[ResponsibleException]] = None):
if classes is None:
classes = (ResponsibleException, )
def _decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def _func(*args, **kwargs):
try:
ret = func(*args, **kwargs)
except tuple(classes) as err:
return err.get_response()
else:
return ret
return _func
return _decorator
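# Illustrative usage sketch (not from the original sources): wiring `responsible` and the
# response helpers above into a Flask view. The app, route and exception subclass below are
# hypothetical examples, not endpoints defined by this repository.
from flask import Flask

class DemoRefused(ResponsibleException):
    def __init__(self):
        ResponsibleException.__init__(
            self, code=CommonErrorCode.COMMON_FAILURE, message='Demo refused.', status_code=403
        )

demo_app = Flask(__name__)

@demo_app.route('/demo', methods=['POST'])
@responsible((DemoRefused, ))
def demo():
    # raising DemoRefused() inside this body would be converted into a 403 failure_response
    return success_response(data={'hello': 'world'})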
import random
import string
from abc import ABCMeta, abstractmethod
from typing import Optional, Callable, Mapping, Any, Dict
_LENGTH_OF_RANDOM_TOKEN = 64
def random_token(length: Optional[int] = None) -> str:
return ''.join([random.choice(string.hexdigits) for _ in range(length or _LENGTH_OF_RANDOM_TOKEN)])
class ControllableContext(metaclass=ABCMeta):
@abstractmethod
def start(self):
raise NotImplementedError # pragma: no cover
@abstractmethod
def close(self):
raise NotImplementedError # pragma: no cover
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
class ControllableService(ControllableContext, metaclass=ABCMeta):
@abstractmethod
def start(self):
raise NotImplementedError # pragma: no cover
@abstractmethod
def shutdown(self):
raise NotImplementedError # pragma: no cover
@abstractmethod
def join(self):
raise NotImplementedError # pragma: no cover
def close(self):
self.shutdown()
self.join()
def translate_dict_func(d: Mapping[str, Callable[..., Any]]) -> Callable[..., Dict[str, Any]]:
def _func(*args, **kwargs) -> Dict[str, Any]:
return {k: f(*args, **kwargs) for k, f in d.items()}
return _func
def default_func(return_value=None) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def _decorator(func: Callable[..., Any]) -> Callable[..., Any]:
# noinspection PyUnusedLocal
def _func(*args, **kwargs):
return return_value
return func or _func
return _decorator
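# Illustrative usage sketch (not from the original sources): a minimal ControllableService
# subclass driven through the context-manager protocol it inherits from ControllableContext.
# The EchoService class below is a hypothetical example.
class EchoService(ControllableService):
    def start(self):
        print('started')
    def shutdown(self):
        print('shutdown requested')
    def join(self):
        print('joined')

with EchoService():         # __enter__ calls start()
    print(random_token(8))  # prints an 8-character hexadecimal token
# __exit__ calls close(), which in turn calls shutdown() and then join()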
import json
import socket
from typing import Optional, Any, Mapping, Callable, Type, Tuple
import requests
from urlobject import URLObject
from urlobject.path import URLPath
from .common import translate_dict_func
def get_host_ip() -> Optional[str]:
s = None
try:
# s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# s.connect(('8.8.8.8', 80))
# ip = s.getsockname()[0]
myname = socket.getfqdn(socket.gethostname())
ip = socket.gethostbyname(myname)
finally:
if s is not None:
s.close()
return ip
_DEFAULT_HTTP_PORT = 80
_DEFAULT_HTTPS_PORT = 443
def split_http_address(address: str, default_port: Optional[int] = None) -> Tuple[str, int, bool, str]:
_url = URLObject(address)
_host = _url.hostname
_https = (_url.scheme.lower()) == 'https'
_port = _url.port or default_port or (_DEFAULT_HTTPS_PORT if _https else _DEFAULT_HTTP_PORT)
_path = str(_url.path) or ''
return _host, _port, _https, _path
class HttpEngine:
def __init__(self, host: str, port: int, https: bool = False, path: str = None):
self.__base_url = URLObject().with_scheme('https' if https else 'http') \
.with_hostname(host).with_port(port).add_path(path or '')
self.__session = requests.session()
# noinspection PyMethodMayBeStatic
def _data_process(self, data: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
return data or {}
# noinspection PyMethodMayBeStatic
    def _base_headers(self) -> Mapping[str, str]:
return {}
def get_url(self, path: str = None):
original_segments = self.__base_url.path.segments
path_segments = URLPath().add(path or '').segments
return str(self.__base_url.with_path(URLPath.join_segments(original_segments + path_segments)))
def request(
self,
method: str,
path: str,
data: Optional[Mapping[str, Any]] = None,
headers: Optional[Mapping[str, Any]] = None,
params: Optional[Mapping[str, Any]] = None,
raise_for_status: bool = True,
) -> requests.Response:
_headers = dict(self._base_headers())
_headers.update(headers or {})
response = self.__session.request(
method=method,
url=self.get_url(path),
data=json.dumps(self._data_process(data) or {}),
headers=_headers or {},
params=params or {},
)
if raise_for_status:
response.raise_for_status()
return response
def get_http_engine_class(
headers: Mapping[str, Callable[..., Any]],
data_processor: Optional[Callable[[Mapping[str, Any]], Mapping[str, Any]]] = None
) -> Callable[..., Type[HttpEngine]]:
def _func(*args, **kwargs) -> Type[HttpEngine]:
class _HttpEngine(HttpEngine):
def _data_process(self, data: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
return (data_processor or (lambda d: d or {}))(data or {})
            def _base_headers(self) -> Mapping[str, str]:
return translate_dict_func(headers)(*args, **kwargs)
return _HttpEngine
return _func
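# Illustrative usage sketch (not from the original sources): splitting an address and building
# a tokenized HttpEngine through get_http_engine_class. The address and token are placeholders.
host, port, https, path = split_http_address('http://di-server.di-system:8080/v1alpha1')
# host='di-server.di-system', port=8080, https=False, path='/v1alpha1'

engine_cls = get_http_engine_class(headers={'Token': lambda token: token})('my-token')
engine = engine_cls(host, port, https, path)
# every engine.request(...) now carries the 'Token: my-token' header, and
# engine.get_url('/replicas') yields something like 'http://di-server.di-system:8080/v1alpha1/replicas'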
from threading import Event, Lock
from typing import Optional
class DblEvent:
def __init__(self, opened: bool = False):
self.__open_event = Event()
self.__close_event = Event()
self.__lock = Lock()
if opened:
self.__open_event.set()
else:
self.__close_event.set()
def wait_for_open(self, timeout: Optional[float] = None):
self.__open_event.wait(timeout=timeout)
def wait_for_close(self, timeout: Optional[float] = None):
self.__close_event.wait(timeout=timeout)
def open(self):
with self.__lock:
self.__open_event.set()
self.__close_event.clear()
def close(self):
with self.__lock:
self.__close_event.set()
self.__open_event.clear()
def is_open(self) -> bool:
with self.__lock:
return self.__open_event.is_set()
def is_close(self) -> bool:
with self.__lock:
return self.__close_event.is_set()
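# Illustrative usage sketch (not from the original sources): DblEvent as a two-state latch
# shared between a worker thread and the main thread.
from threading import Thread
import time

latch = DblEvent()              # starts in the closed state

def _worker():
    latch.wait_for_open()       # blocks until the main thread opens the latch
    time.sleep(0.1)             # pretend to do some work
    latch.close()

worker = Thread(target=_worker)
worker.start()
latch.open()
latch.wait_for_close()          # blocks until the worker closes the latch again
worker.join()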
from .base import MIN_HEARTBEAT_CHECK_SPAN, MIN_HEARTBEAT_SPAN, DEFAULT_MASTER_PORT, DEFAULT_SLAVE_PORT, \
DEFAULT_CHANNEL, DEFAULT_HEARTBEAT_CHECK_SPAN, DEFAULT_HEARTBEAT_TOLERANCE, DEFAULT_HEARTBEAT_SPAN, LOCAL_HOST, \
GLOBAL_HOST
# System configs
GLOBAL_HOST = '0.0.0.0'
LOCAL_HOST = '127.0.0.1'
# Slave configs
MIN_HEARTBEAT_SPAN = 0.2
DEFAULT_HEARTBEAT_SPAN = 3.0
DEFAULT_HEARTBEAT_TOLERANCE = 15.0
DEFAULT_SLAVE_PORT = 7236
# Master configs
MIN_HEARTBEAT_CHECK_SPAN = 0.1
DEFAULT_HEARTBEAT_CHECK_SPAN = 1.0
DEFAULT_MASTER_PORT = 7235
# Two-side configs
DEFAULT_CHANNEL = 0
from .error_code import MasterErrorCode
from .master import Master
from typing import Callable, Mapping, Any, Optional
from requests import RequestException
_BEFORE_HOOK_TYPE = Callable[..., Mapping[str, Any]]
_AFTER_HOOK_TYPE = Callable[[int, bool, int, Optional[str], Optional[Mapping[str, Any]]], Any]
_ERROR_HOOK_TYPE = Callable[[
RequestException,
], Any]
from abc import ABCMeta, abstractmethod
from functools import wraps
from threading import Lock
from typing import Optional, Any, Mapping, Type, Callable
from uuid import uuid4, UUID
import requests
from requests.exceptions import RequestException
from .base import _BEFORE_HOOK_TYPE, _AFTER_HOOK_TYPE, _ERROR_HOOK_TYPE
from .task import Task, _task_complete, _task_fail
from ..base import random_token, ControllableContext, get_http_engine_class, get_values_from_response
from ..config import DEFAULT_CHANNEL, DEFAULT_SLAVE_PORT
_COMPLETE_TRIGGER_NAME = '__TASK_COMPLETE__'
_FAIL_TRIGGER_NAME = '__TASK_FAIL__'
class _ISlaveConnection(ControllableContext, metaclass=ABCMeta):
@abstractmethod
def connect(self):
raise NotImplementedError # pragma: no cover
@abstractmethod
def disconnect(self):
raise NotImplementedError # pragma: no cover
@abstractmethod
def new_task(self, data: Optional[Mapping[str, Any]] = None):
raise NotImplementedError # pragma: no cover
def start(self):
self.connect()
def close(self):
self.disconnect()
class SlaveConnection(_ISlaveConnection, metaclass=ABCMeta):
def __init__(
self,
host: str,
port: Optional[int] = None,
https: bool = False,
channel: Optional[int] = None,
my_address: Optional[str] = None,
token: Optional[str] = None
):
# meta info part
self.__channel = channel or DEFAULT_CHANNEL
self.__my_address = my_address
self.__token = token or random_token()
# request part
self.__http_engine = get_http_engine_class(
headers={
'Channel': lambda: str(self.__channel),
'Token': lambda: self.__token,
}
)()(host, port or DEFAULT_SLAVE_PORT, https)
# threading part
self.__lock = Lock()
self.__is_connected = False
# task part
self.__tasks = {}
self.__init_triggers()
def __request(self, method: str, path: str, data: Optional[Mapping[str, Any]] = None) -> requests.Response:
return self.__http_engine.request(method, path, data)
@property
def is_connected(self) -> bool:
with self.__lock:
return self.__is_connected
def _before_connect(self) -> Mapping[str, Any]:
pass # pragma: no cover
def _after_connect(
self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
) -> Any:
pass # pragma: no cover
def _error_connect(self, error: RequestException) -> Any:
raise error # pragma: no cover
def __connect(self):
try:
response = self.__request(
'POST', '/connect', {
'master': {
'address': self.__my_address,
},
'data': (self._before_connect() or {})
}
)
except RequestException as err:
return self._error_connect(err)
else:
self.__is_connected = True
return self._after_connect(*get_values_from_response(response))
def connect(self):
with self.__lock:
return self.__connect()
def _before_disconnect(self) -> Mapping[str, Any]:
pass # pragma: no cover
def _after_disconnect(
self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
) -> Any:
pass # pragma: no cover
def _error_disconnect(self, error: RequestException) -> Any:
raise error # pragma: no cover
def __disconnect(self):
try:
response = self.__request('DELETE', '/disconnect', {
'data': self._before_disconnect() or {},
})
except RequestException as err:
print(err)
return self._error_disconnect(err)
else:
self.__is_connected = False
return self._after_disconnect(*get_values_from_response(response))
def disconnect(self):
with self.__lock:
return self.__disconnect()
def _after_update_learner(
self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
) -> Any:
return data
def _error_update_learners(self, error: RequestException) -> Any:
raise error
def __update_learners(self, learners):
try:
response = self.__request('POST', '/learners', learners)
except RequestException as err:
return self._error_update_learners(err)
else:
return self._after_update_learner(*get_values_from_response(response))
def update_learners(self, learners):
with self.__lock:
return self.__update_learners(learners)
def _before_new_task(self, data: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
return data # pragma: no cover
def _after_new_task(
self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
) -> Any:
pass # pragma: no cover
def _error_new_task(self, error: RequestException) -> Any:
raise error # pragma: no cover
def new_task(self, data: Optional[Mapping[str, Any]] = None) -> Task:
with self.__lock:
_uuid = uuid4()
_task = Task(
http_engine=self.__http_engine,
data=data,
task_id=_uuid,
before_task_start=self._before_new_task,
after_task_start=self._after_new_task,
error_task_start=self._error_new_task,
)
self.__tasks[_uuid] = _task
return _task
def __task_complete(self, task_id: UUID, task_result: Mapping[str, Any]):
_task = self.__tasks[task_id]
_task_complete(_task, task_result)
del self.__tasks[task_id]
def __task_fail(self, task_id: UUID, task_result: Mapping[str, Any]):
_task = self.__tasks[task_id]
_task_fail(_task, task_result)
del self.__tasks[task_id]
def __task_complete_trigger(self, task_id: UUID, task_result: Mapping[str, Any]):
with self.__lock:
if task_id in self.__tasks.keys():
return self.__task_complete(task_id, task_result)
else:
raise KeyError("Task {uuid} not found in this connection.".format(uuid=repr(str(task_id))))
def __task_fail_trigger(self, task_id: UUID, task_result: Mapping[str, Any]):
with self.__lock:
if task_id in self.__tasks.keys():
return self.__task_fail(task_id, task_result)
else:
raise KeyError("Task {uuid} not found in this connection.".format(uuid=repr(str(task_id))))
def __init_triggers(self):
setattr(self, _COMPLETE_TRIGGER_NAME, self.__task_complete_trigger)
setattr(self, _FAIL_TRIGGER_NAME, self.__task_fail_trigger)
def _connection_task_complete(connection: SlaveConnection, task_id: UUID, task_result: Mapping[str, Any]):
return getattr(connection, _COMPLETE_TRIGGER_NAME)(task_id, task_result)
def _connection_task_fail(connection: SlaveConnection, task_id: UUID, task_result: Mapping[str, Any]):
return getattr(connection, _FAIL_TRIGGER_NAME)(task_id, task_result)
class SlaveConnectionProxy(_ISlaveConnection):
def __init__(
self,
connection: SlaveConnection,
after_connect: Optional[Callable] = None,
after_disconnect: Optional[Callable] = None
):
self.__connection = connection
self.__lock = Lock()
self.__after_connect = after_connect
self.__after_disconnect = after_disconnect
self.__init_triggers()
@property
def is_connected(self) -> bool:
with self.__lock:
return self.__connection.is_connected
def connect(self):
with self.__lock:
result = self.__connection.connect()
if self.__after_connect is not None:
self.__after_connect(connection=self)
return result
def disconnect(self):
with self.__lock:
result = self.__connection.disconnect()
if self.__after_disconnect is not None:
self.__after_disconnect(connection=self)
return result
def new_task(self, data: Optional[Mapping[str, Any]] = None):
with self.__lock:
return self.__connection.new_task(data)
def update_learners(self, learners):
with self.__lock:
return self.__connection.update_learners(learners)
def __task_complete_trigger(self, task_id: UUID, task_result: Mapping[str, Any]):
with self.__lock:
return _connection_task_complete(self.__connection, task_id, task_result)
def __task_fail_trigger(self, task_id: UUID, task_result: Mapping[str, Any]):
with self.__lock:
return _connection_task_fail(self.__connection, task_id, task_result)
def __init_triggers(self):
setattr(self, _COMPLETE_TRIGGER_NAME, self.__task_complete_trigger)
setattr(self, _FAIL_TRIGGER_NAME, self.__task_fail_trigger)
def _proxy_task_complete(proxy: SlaveConnectionProxy, task_id: UUID, task_result: Mapping[str, Any]):
return getattr(proxy, _COMPLETE_TRIGGER_NAME)(task_id, task_result)
def _proxy_task_fail(proxy: SlaveConnectionProxy, task_id: UUID, task_result: Mapping[str, Any]):
return getattr(proxy, _FAIL_TRIGGER_NAME)(task_id, task_result)
def _slave_task_complete(connection: _ISlaveConnection, task_id: UUID, task_result: Mapping[str, Any]):
if isinstance(connection, SlaveConnection):
return _connection_task_complete(connection, task_id, task_result)
elif isinstance(connection, SlaveConnectionProxy):
return _proxy_task_complete(connection, task_id, task_result)
else:
raise TypeError(
"{expect1} or {expect2} expected, but {actual} found.".format(
expect1=SlaveConnection.__name__,
expect2=SlaveConnectionProxy.__name__,
actual=type(connection).__name__,
)
)
def _slave_task_fail(connection: _ISlaveConnection, task_id: UUID, task_result: Mapping[str, Any]):
if isinstance(connection, SlaveConnection):
return _connection_task_fail(connection, task_id, task_result)
elif isinstance(connection, SlaveConnectionProxy):
return _proxy_task_fail(connection, task_id, task_result)
else:
raise TypeError(
"{expect1} or {expect2} expected, but {actual} found.".format(
expect1=SlaveConnection.__name__,
expect2=SlaveConnectionProxy.__name__,
actual=type(connection).__name__,
)
)
def _default_wrap(func: Callable) -> Callable:
@wraps(func)
def _new_func(*args, **kwargs):
if func:
return func(*args, **kwargs)
else:
return None
return _new_func
def _get_connection_class(
before_new_task: Optional[_BEFORE_HOOK_TYPE] = None,
after_new_task: Optional[_AFTER_HOOK_TYPE] = None,
error_new_task: Optional[_ERROR_HOOK_TYPE] = None,
before_connect: Optional[_BEFORE_HOOK_TYPE] = None,
after_connect: Optional[_AFTER_HOOK_TYPE] = None,
error_connect: Optional[_ERROR_HOOK_TYPE] = None,
before_disconnect: Optional[_BEFORE_HOOK_TYPE] = None,
after_disconnect: Optional[_AFTER_HOOK_TYPE] = None,
error_disconnect: Optional[_ERROR_HOOK_TYPE] = None,
) -> Type[SlaveConnection]:
class _Connection(SlaveConnection):
def _before_connect(self) -> Mapping[str, Any]:
return _default_wrap(before_connect)() or {}
def _after_connect(
self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str,
Any]]
) -> Any:
return _default_wrap(after_connect)(status_code, success, code, message, data)
def _error_connect(self, error: RequestException) -> Any:
return _default_wrap(error_connect)(error)
def _before_disconnect(self) -> Mapping[str, Any]:
return _default_wrap(before_disconnect)() or {}
def _after_disconnect(
self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str,
Any]]
) -> Any:
return _default_wrap(after_disconnect)(status_code, success, code, message, data)
def _error_disconnect(self, error: RequestException) -> Any:
return _default_wrap(error_disconnect)(error)
def _before_new_task(self, data: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
return _default_wrap(before_new_task)(data) or {}
def _after_new_task(
self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str,
Any]]
) -> Any:
return _default_wrap(after_new_task)(status_code, success, code, message, data)
def _error_new_task(self, error: RequestException) -> Any:
return _default_wrap(error_new_task)(error)
return _Connection
class ServerConnection:
def __init__(
self,
host: str,
port: Optional[int] = None,
api_version: str = "v1alpha1",
https: bool = False,
namespace: str = None,
name: str = None,
):
# request part
self.__http_engine = get_http_engine_class(headers={})()(host, port, https)
self.__api_version = api_version
self.__namespace = namespace
self.__my_name = name
@property
def api_version(self):
return self.__api_version
def __prefix_with_api_version(self, path):
return self.__api_version + path
def get_replicas(self, name: str = None):
try:
if name is None:
params = {"namespace": self.__namespace, "coordinator": self.__my_name}
else:
params = {"namespace": self.__namespace, "name": name}
response = self.__http_engine.request('GET', self.__prefix_with_api_version('/replicas'), params=params)
except RequestException as err:
return self._error_request(err)
else:
return self._after_request(*get_values_from_response(response))
def post_replicas(self, data):
try:
data.update(
{
"namespace": self.__namespace,
"coordinator": self.__my_name
}
)
response = self.__http_engine.request('POST', self.__prefix_with_api_version('/replicas'), data=data)
except RequestException as err:
return self._error_request(err)
else:
return self._after_request(*get_values_from_response(response))
    def post_replicas_failed(self, collectors=None, learners=None):
        try:
            data = {
                "namespace": self.__namespace,
                "coordinator": self.__my_name,
                "collectors": collectors or [],
                "learners": learners or [],
            }
response = self.__http_engine.request('POST', self.__prefix_with_api_version('/replicas/failed'), data=data)
except RequestException as err:
return self._error_request(err)
else:
return self._after_request(*get_values_from_response(response))
def delete_replicas(self, n_collectors=0, n_learners=0):
try:
data = {
"namespace": self.__namespace,
"coordinator": self.__my_name,
"collectors": {
"replicas": n_collectors,
},
"learners":{
"replicas": n_learners,
}
}
response = self.__http_engine.request('DELETE', self.__prefix_with_api_version('/replicas'), data=data)
except RequestException as err:
return self._error_request(err)
else:
return self._after_request(*get_values_from_response(response))
def _after_request(
self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
) -> Any:
return success, code, message, data
def _error_request(self, error: RequestException) -> Any:
# raise error
pass
\ No newline at end of file
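# Illustrative usage sketch (not from the original sources): querying and scaling replicas
# through the ServerConnection above. The host, namespace and coordinator name are
# placeholders, and a reachable di-server is assumed.
server_conn = ServerConnection(
    host='di-server.di-system',
    port=8080,
    api_version='v1alpha1',
    namespace='default',
    name='dijob-example-coordinator',
)
result = server_conn.get_replicas()
if result is not None:                 # _error_request swallows RequestException and returns None
    success, code, message, data = result
    print(success, code, message, data)
server_conn.delete_replicas(n_collectors=1, n_learners=1)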
from enum import unique, IntEnum
@unique
class MasterErrorCode(IntEnum):
SUCCESS = 0
SYSTEM_SHUTTING_DOWN = 101
CHANNEL_NOT_GIVEN = 201
CHANNEL_INVALID = 202
MASTER_TOKEN_NOT_GIVEN = 301
MASTER_TOKEN_INVALID = 302
SELF_TOKEN_NOT_GIVEN = 401
SELF_TOKEN_INVALID = 402
SLAVE_TOKEN_NOT_GIVEN = 501
SLAVE_TOKEN_INVALID = 502
TASK_DATA_INVALID = 601
import json
import time
from functools import wraps, partial
from queue import Queue, Empty
from threading import Lock, Thread, Event
from typing import Optional, Any, Mapping, Type, Callable
from uuid import UUID
import requests
from flask import Flask, request
from requests.exceptions import RequestException
from urlobject import URLObject
from .connection import SlaveConnectionProxy, SlaveConnection, _ISlaveConnection, _get_connection_class, \
_slave_task_complete, _slave_task_fail, ServerConnection
from .error_code import MasterErrorCode
from .task import TaskResultType
from ..base import random_token, ControllableService, failure_response, success_response, get_host_ip, \
get_http_engine_class
from ..config import GLOBAL_HOST, DEFAULT_MASTER_PORT, DEFAULT_CHANNEL, MIN_HEARTBEAT_SPAN, DEFAULT_HEARTBEAT_SPAN, \
DEFAULT_HEARTBEAT_TOLERANCE, MIN_HEARTBEAT_CHECK_SPAN, DEFAULT_HEARTBEAT_CHECK_SPAN
class Master(ControllableService):
def __init__(
self,
host: Optional[str] = None,
port: Optional[int] = None,
heartbeat_span: Optional[float] = None,
heartbeat_tolerance: Optional[float] = None,
heartbeat_check_span: Optional[float] = None,
channel: Optional[int] = None,
my_address: Optional[str] = None
):
# server part
self.__host = host or GLOBAL_HOST
self.__port = port or DEFAULT_MASTER_PORT
self.__flask_app_value = None
self.__run_app_thread = Thread(target=self.__run_app)
# heartbeat part
self.__heartbeat_span = max(heartbeat_span or DEFAULT_HEARTBEAT_SPAN, MIN_HEARTBEAT_SPAN)
self.__heartbeat_tolerance = max(heartbeat_tolerance or DEFAULT_HEARTBEAT_TOLERANCE, MIN_HEARTBEAT_SPAN)
self.__heartbeat_check_span = max(
heartbeat_check_span or DEFAULT_HEARTBEAT_CHECK_SPAN, MIN_HEARTBEAT_CHECK_SPAN
)
self.__heartbeat_check_thread = Thread(target=self.__heartbeat_check)
# self-connection part
self.__self_http_engine = get_http_engine_class(headers={
'Token': lambda: self.__self_token,
# })()('localhost', self.__port, False)
})()(self.__host, self.__port, False)
self.__self_token = random_token()
# slave-connection part
self.__channel = channel or DEFAULT_CHANNEL
self.__my_address = my_address or str(
URLObject().with_scheme('http').with_hostname(get_host_ip()).with_port(self.__port)
# URLObject().with_scheme('http').with_hostname(self.__host).with_port(self.__port)
)
# slaves part
self.__slaves = {} # name --> (token, slave_connection)
self.__token_slaves = {} # token --> (name, slave_connection)
self.__slave_last_heartbeat = {} # name --> last_heartbeat
self.__slave_lock = Lock()
# task part
self.__task_result_queue = Queue()
self.__task_result_process_thread = Thread(target=self.__task_result_process)
# global part
self.__shutdown_event = Event()
self.__lock = Lock()
# k8s: di-server
self.__server_http_engine = None
# slave connection
def __connection_open(self, name: str, token: str, connection: SlaveConnectionProxy):
with self.__slave_lock:
self.__slaves[name] = (token, connection)
self.__token_slaves[token] = (name, connection)
self.__slave_last_heartbeat[name] = time.time()
# noinspection PyUnusedLocal
def __connection_close(self, name: str, connection: Optional[SlaveConnectionProxy] = None):
with self.__slave_lock:
token, _conn = self.__slaves[name]
connection = connection or _conn
del self.__slaves[name]
del self.__token_slaves[token]
del self.__slave_last_heartbeat[name]
# server part
def __generate_app(self):
app = Flask(__name__)
# self apis
app.route('/ping', methods=['GET'])(self.__check_self_request(self.__self_ping))
app.route('/shutdown', methods=['DELETE'])(self.__check_self_request(self.__self_shutdown))
# slave apis
app.route('/slave/heartbeat', methods=['GET'])(self.__check_slave_request(self.__heartbeat))
app.route(
'/slave/task/complete', methods=['PUT']
)(self.__check_slave_request(self.__check_task_info(self.__task_complete)))
app.route(
'/slave/task/fail', methods=['PUT']
)(self.__check_slave_request(self.__check_task_info(self.__task_fail)))
return app
def __flask_app(self) -> Flask:
return self.__flask_app_value or self.__generate_app()
def __run_app(self):
while True:
try:
self.__flask_app().run(
host=self.__host,
port=self.__port,
)
            except Exception:
print("failed to run flask app on {}:{}..".format(self.__host, self.__port))
else:
break
# both method checkers
def __check_shutdown(self, func: Callable[[], Any]) -> Callable[[], Any]:
@wraps(func)
def _func():
if self.__shutdown_event.is_set():
return failure_response(
                    code=MasterErrorCode.SYSTEM_SHUTTING_DOWN, message='System is already shutting down.'
), 401
else:
return func()
return _func
# server method checkers (self)
# noinspection DuplicatedCode
def __check_self_request(self, func: Callable[[], Any]) -> Callable[[], Any]:
return self.__check_shutdown(self.__check_master_token(func))
def __check_master_token(self, func: Callable[[], Any]) -> Callable[[], Any]:
@wraps(func)
def _func():
master_token = request.headers.get('Token', None)
if master_token is None:
return failure_response(
code=MasterErrorCode.SELF_TOKEN_NOT_GIVEN, message='Master token not found.'
), 400
elif master_token != self.__self_token:
return failure_response(
                    code=MasterErrorCode.SELF_TOKEN_INVALID, message='Master token does not match this endpoint.'
), 403
else:
return func()
return _func
# server method checkers (slave)
def __check_slave_request(self, func: Callable[[str, _ISlaveConnection], Any]) -> Callable[[], Any]:
return self.__check_shutdown(self.__check_channel(self.__check_slave_token(func)))
# noinspection DuplicatedCode
def __check_channel(self, func: Callable[[], Any]) -> Callable[[], Any]:
@wraps(func)
def _func():
channel = request.headers.get('Channel', None)
channel = int(channel) if channel else None
if channel is None:
return failure_response(code=MasterErrorCode.CHANNEL_NOT_GIVEN, message='Channel not found.'), 400
elif channel != self.__channel:
return failure_response(
                    code=MasterErrorCode.CHANNEL_INVALID, message='Channel does not match this endpoint.'
), 403
else:
return func()
return _func
def __check_slave_token(self, func: Callable[[str, _ISlaveConnection], Any]) -> Callable[[], Any]:
@wraps(func)
def _func():
slave_token = request.headers.get('Token', None)
if slave_token is None:
return failure_response(
code=MasterErrorCode.SLAVE_TOKEN_NOT_GIVEN, message='Slave token not found.'
), 400
elif slave_token not in self.__token_slaves.keys():
return failure_response(
code=MasterErrorCode.SLAVE_TOKEN_INVALID, message='No matching slave token found in this endpoint.'
), 403
else:
name, connection = self.__token_slaves[slave_token]
return func(name, connection)
return _func
# noinspection PyMethodMayBeStatic
def __get_request_data(self, func: Callable[[str, _ISlaveConnection, Mapping[str, Any]], Any]) \
-> Callable[[str, _ISlaveConnection], Any]:
@wraps(func)
def _func(name: str, connection: _ISlaveConnection):
_data = json.loads(request.data.decode())
return func(name, connection, _data)
return _func
def __check_task_info(self, func: Callable[[str, _ISlaveConnection, UUID, Mapping[str, Any]], Any]) \
-> Callable[[str, _ISlaveConnection], Any]:
@wraps(func)
@self.__get_request_data
def _func(name: str, connection: _ISlaveConnection, data: Mapping[str, Any]):
if 'task' not in data.keys():
return failure_response(
code=MasterErrorCode.TASK_DATA_INVALID,
message='Task information not found.',
)
_task_info, _task_result = data['task'], data['result']
if 'id' not in _task_info.keys():
return failure_response(code=MasterErrorCode.TASK_DATA_INVALID, message='Task ID not found.')
_task_id = UUID(_task_info['id'])
return func(name, connection, _task_id, _task_result)
return _func
# server methods (self)
# noinspection PyMethodMayBeStatic
def __self_ping(self):
return success_response(message='PONG!')
def __self_shutdown(self):
_shutdown_func = request.environ.get('werkzeug.server.shutdown')
if _shutdown_func is None:
raise RuntimeError('Not running with the Werkzeug Server')
self.__shutdown_event.set()
_shutdown_func()
return success_response(message='Shutdown request received, this server will be down later.')
# server methods (slave)
# noinspection PyMethodMayBeStatic,PyUnusedLocal
def __heartbeat(self, name: str, connection: _ISlaveConnection):
self.__slave_last_heartbeat[name] = time.time()
return success_response(message='Received!')
# noinspection PyUnusedLocal
def __task_complete(self, name: str, connection: _ISlaveConnection, task_id: UUID, task_result: Mapping[str, Any]):
self.__task_result_queue.put((TaskResultType.COMPLETED, (connection, task_id, task_result)))
return success_response(message='Result received!')
# noinspection PyUnusedLocal
def __task_fail(self, name: str, connection: _ISlaveConnection, task_id: UUID, task_result: Mapping[str, Any]):
self.__task_result_queue.put((TaskResultType.FAILED, (connection, task_id, task_result)))
return success_response(message='Result received!')
# self request
def __self_request(self, method: Optional[str] = 'GET', path: Optional[str] = None) -> requests.Response:
return self.__self_http_engine.request(method, path)
def __ping_once(self):
return self.__self_request('GET', '/ping')
def __ping_until_started(self):
while True:
try:
self.__ping_once()
except (requests.exceptions.BaseHTTPError, requests.exceptions.RequestException):
time.sleep(0.2)
else:
break
def __shutdown(self):
self.__self_request('DELETE', '/shutdown')
# heartbeat part
def __heartbeat_check(self):
_last_time = time.time()
while not self.__shutdown_event.is_set():
_current_time = time.time()
_common_names = set(self.__slaves.keys()) & set(self.__slave_last_heartbeat.keys())
for name in _common_names:
_, connection = self.__slaves[name]
last_heartbeat = self.__slave_last_heartbeat[name]
if _current_time - last_heartbeat > self.__heartbeat_tolerance:
self.__connection_close(name, connection)
_last_time += self.__heartbeat_check_span
            time.sleep(max(_last_time - time.time(), 0.0))  # avoid a negative sleep if a check overruns the span
# task process part
def __task_result_process(self):
while not self.__task_result_queue.empty() or not self.__shutdown_event.is_set():
try:
_result = self.__task_result_queue.get(timeout=3.0)
except Empty:
continue
else:
_type, (_connection, _task_id, _task_result) = _result
_trigger_func = _slave_task_complete if _type == TaskResultType.COMPLETED else _slave_task_fail
_trigger_func(_connection, _task_id, _task_result)
# connection part
def __get_connection_class(self) -> Type[SlaveConnection]:
return _get_connection_class(
before_new_task=self._before_new_task,
after_new_task=self._after_new_task,
error_new_task=self._error_new_task,
before_connect=self._before_connect,
after_connect=self._after_connect,
error_connect=self._error_connect,
before_disconnect=self._before_disconnect,
after_disconnect=self._after_disconnect,
error_disconnect=self._error_disconnect,
)
def __get_new_connection(
self, name: str, host: str, port: Optional[int] = None, https: bool = False
) -> SlaveConnectionProxy:
if name in self.__slaves.keys():
            raise KeyError('Connection {name} already exists.'.format(name=repr(name)))
else:
slave_token = random_token()
connection = self.__get_connection_class()(
host=host,
port=port,
https=https,
channel=self.__channel,
my_address=self.__my_address,
token=slave_token,
)
return SlaveConnectionProxy(
connection=connection,
after_connect=partial(self.__connection_open, name=name, token=slave_token),
after_disconnect=partial(self.__connection_close, name=name),
)
# public properties
@property
def my_address(self) -> str:
with self.__lock:
return self.__my_address
# public methods
def ping(self) -> bool:
with self.__lock:
try:
self.__ping_once()
except (requests.exceptions.BaseHTTPError, requests.exceptions.RequestException):
return False
else:
return True
def new_connection(
self, name: str, host: str, port: Optional[int] = None, https: bool = False
) -> SlaveConnectionProxy:
with self.__lock:
return self.__get_new_connection(name, host, port, https)
def setup_server_conn(
self, host: str, port: int, api_version: str, namespace: str, name: str, https: bool = False
):
return ServerConnection(
host=host,
port=port,
api_version=api_version,
namespace=namespace,
name=name,
https=https
)
def __contains__(self, name: str):
with self.__lock:
return name in self.__slaves.keys()
def __getitem__(self, name: str):
with self.__lock:
if name in self.__slaves.keys():
_token, _connection = self.__slaves[name]
return _connection
else:
raise KeyError('Connection {name} not found.'.format(name=repr(name)))
def __delitem__(self, name: str):
with self.__lock:
if name in self.__slaves.keys():
_token, _connection = self.__slaves[name]
_connection.disconnect()
else:
raise KeyError('Connection {name} not found.'.format(name=repr(name)))
def start(self):
with self.__lock:
self.__task_result_process_thread.start()
self.__heartbeat_check_thread.start()
self.__run_app_thread.start()
self.__ping_until_started()
def shutdown(self):
with self.__lock:
self.__shutdown()
def join(self):
with self.__lock:
self.__run_app_thread.join()
self.__heartbeat_check_thread.join()
self.__task_result_process_thread.join()
# inherit methods
def _before_connect(self) -> Mapping[str, Any]:
pass
def _after_connect(
self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
) -> Any:
pass
def _error_connect(self, error: RequestException):
raise error
def _before_disconnect(self) -> Mapping[str, Any]:
pass
def _after_disconnect(
self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
) -> Any:
pass
def _error_disconnect(self, error: RequestException):
raise error
# noinspection PyMethodMayBeStatic
def _before_new_task(self, data: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
return data or {}
def _after_new_task(
self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
) -> Any:
pass
def _error_new_task(self, error: RequestException):
# raise error
pass
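# Illustrative usage sketch (not from the original sources): running a Master, connecting it
# to one slave endpoint and pushing a task. The slave host/port are placeholders, and a Slave
# must already be listening there for the task round-trip to succeed.
master = Master(port=7235)
master.start()
try:
    with master.new_connection('slave-0', host='127.0.0.1', port=7236) as conn:
        task = conn.new_task({'job': 'demo'}).start().join()
        print(task.status, task.result)
finally:
    master.shutdown()
    master.join()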
from enum import unique, IntEnum
from threading import Lock
from typing import Mapping, Any, Optional, Callable
from uuid import UUID, uuid4
import requests
from requests import RequestException
from .base import _BEFORE_HOOK_TYPE, _AFTER_HOOK_TYPE, _ERROR_HOOK_TYPE
from ..base import HttpEngine, get_values_from_response, default_func
@unique
class TaskResultType(IntEnum):
COMPLETED = 1
FAILED = 2
@unique
class TaskStatus(IntEnum):
IDLE = 0x00
STARTING = 0x11
STARTED = 0x12
START_FAILED = 0x13
COMPLETED = 0x21
FAILED = 0x22
_COMPLETE_TRIGGER_NAME = '__TASK_COMPLETE__'
_FAIL_TRIGGER_NAME = '__TASK_FAIL__'
class Task:
def __init__(
self,
http_engine: HttpEngine,
data: Mapping[str, Any],
task_id: Optional[UUID] = None,
before_task_start: Optional[_BEFORE_HOOK_TYPE] = None,
after_task_start: Optional[_AFTER_HOOK_TYPE] = None,
error_task_start: Optional[_ERROR_HOOK_TYPE] = None
):
self.__http_engine = http_engine
self.__lock = Lock()
self.__task_id = task_id or uuid4()
self.__task_data = data
self.__task_result = None
self.__task_status = TaskStatus.IDLE
self.__task_lock = Lock()
self.__before_task_start = before_task_start or (lambda d: d)
self.__after_task_start = default_func(None)(after_task_start)
self.__error_task_start = default_func(None)(error_task_start)
self.__after_task_completed_callbacks = []
self.__after_task_failed_callbacks = []
self.__init_triggers()
def __request(self, method: str, path: str, data: Optional[Mapping[str, Any]] = None) -> requests.Response:
return self.__http_engine.request(method, path, data)
def __task_start(self):
try:
self.__task_status = TaskStatus.STARTING
response = self.__request(
'POST', '/task/new', {
'task': {
'id': str(self.__task_id)
},
'data': self.__before_task_start(self.__task_data) or {}
}
)
except RequestException as err:
self.__task_status = TaskStatus.START_FAILED
return self.__error_task_start(err)
else:
self.__task_status = TaskStatus.STARTED
ret = self.__after_task_start(*get_values_from_response(response))
self.__task_lock.acquire()
return ret
def __task_complete(self, result: Mapping[str, Any]):
self.__task_status = TaskStatus.COMPLETED
self.__task_result = result
for _callback in self.__after_task_completed_callbacks:
_callback(self.__task_data, result)
self.__task_lock.release()
def __task_fail(self, result: Mapping[str, Any]):
self.__task_status = TaskStatus.FAILED
self.__task_result = result
for _callback in self.__after_task_failed_callbacks:
_callback(self.__task_data, result)
self.__task_lock.release()
# trigger methods
def __task_complete_trigger(self, result: Mapping[str, Any]):
with self.__lock:
if self.__task_status == TaskStatus.STARTED:
self.__task_complete(result)
else:
raise ValueError(
"Only task with {expect} status can be completed, but {actual} found.".format(
expect=repr(TaskStatus.STARTED.name),
actual=repr(self.__task_status.name),
)
)
def __task_fail_trigger(self, result: Mapping[str, Any]):
with self.__lock:
if self.__task_status == TaskStatus.STARTED:
self.__task_fail(result)
else:
raise ValueError(
"Only task with {expect} status can be failed, but {actual} found.".format(
expect=repr(TaskStatus.STARTED.name),
actual=repr(self.__task_status.name),
)
)
def __init_triggers(self):
setattr(self, _COMPLETE_TRIGGER_NAME, self.__task_complete_trigger)
setattr(self, _FAIL_TRIGGER_NAME, self.__task_fail_trigger)
# public properties
@property
def status(self) -> TaskStatus:
return self.__task_status
@property
def task(self) -> Mapping[str, Any]:
return self.__task_data
@property
def result(self) -> Optional[Mapping[str, Any]]:
return self.__task_result
# public methods
def start(self) -> 'Task':
with self.__lock:
if self.__task_status == TaskStatus.IDLE:
self.__task_start()
return self
else:
raise ValueError(
"Only task with {expect} status can be started, but {actual} found.".format(
expect=repr(TaskStatus.IDLE.name),
actual=repr(self.__task_status.name),
)
)
def join(self) -> 'Task':
with self.__task_lock:
return self
def on_complete(self, callback: Callable[[Mapping[str, Any], Mapping[str, Any]], Any]) -> 'Task':
with self.__lock:
self.__after_task_completed_callbacks.append(callback)
return self
def on_fail(self, callback: Callable[[Mapping[str, Any], Mapping[str, Any]], Any]) -> 'Task':
with self.__lock:
self.__after_task_failed_callbacks.append(callback)
return self
def _task_complete(task: Task, result: Mapping[str, Any]):
getattr(task, _COMPLETE_TRIGGER_NAME)(result)
def _task_fail(task: Task, result: Mapping[str, Any]):
getattr(task, _FAIL_TRIGGER_NAME)(result)
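# Illustrative usage sketch (not from the original sources): attaching result callbacks to a
# Task before starting it. `conn` is assumed to be an already-connected SlaveConnection or
# SlaveConnectionProxy from connection.py above.
task = conn.new_task({'job': 'demo'})
task.on_complete(lambda data, result: print('completed:', result))
task.on_fail(lambda data, result: print('failed:', result))
task.start().join()
print(task.status is TaskStatus.COMPLETED)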
from .action import TaskRefuse, DisconnectionRefuse, ConnectionRefuse, TaskFail
from .error_code import SlaveErrorCode
from .slave import Slave
from typing import Optional, Any, Mapping
from .error_code import SlaveErrorCode
from ..base import ResponsibleException
class ConnectionRefuse(ResponsibleException):
def __init__(self, data: Optional[Mapping[str, Any]] = None):
ResponsibleException.__init__(
self,
SlaveErrorCode.SLAVE_CONNECTION_REFUSED,
message='Connection refused!',
data=data or {},
status_code=403,
)
class DisconnectionRefuse(ResponsibleException):
def __init__(self, data: Optional[Mapping[str, Any]] = None):
ResponsibleException.__init__(
self,
SlaveErrorCode.SLAVE_DISCONNECTION_REFUSED,
message='Disconnection refused!',
data=data or {},
status_code=403,
)
class TaskRefuse(ResponsibleException):
def __init__(self, data: Optional[Mapping[str, Any]] = None):
ResponsibleException.__init__(
self,
SlaveErrorCode.TASK_REFUSED,
message='Task refused!',
data=data or {},
status_code=403,
)
class TaskFail(Exception):
def __init__(self, result: Optional[Mapping[str, Any]], message: Optional[str] = None):
if message:
Exception.__init__(self, 'Task process failed - {message}.'.format(message=message))
else:
Exception.__init__(self, 'Task process failed.')
self.__result = result or {}
@property
def result(self) -> Mapping[str, Any]:
return self.__result
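# Illustrative usage sketch (not from the original sources): these exceptions are meant to be
# raised from Slave-side hooks; a hypothetical pre-connection hook that rejects a busy slave
# would surface as a 403 failure_response via ConnectionRefuse.get_response().
def _before_connection_hook(slave_is_busy: bool):
    if slave_is_busy:
        raise ConnectionRefuse(data={'reason': 'slave busy'})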
from enum import unique, IntEnum
@unique
class SlaveErrorCode(IntEnum):
SUCCESS = 0
SYSTEM_SHUTTING_DOWN = 101
CHANNEL_NOT_FOUND = 201
CHANNEL_INVALID = 202
MASTER_TOKEN_NOT_FOUND = 301
MASTER_TOKEN_INVALID = 302
SELF_TOKEN_NOT_FOUND = 401
SELF_TOKEN_INVALID = 402
SLAVE_ALREADY_CONNECTED = 501
SLAVE_NOT_CONNECTED = 502
SLAVE_CONNECTION_REFUSED = 503
SLAVE_DISCONNECTION_REFUSED = 504
TASK_ALREADY_EXIST = 601
TASK_REFUSED = 602
import json
import sys
import time
import traceback
from abc import abstractmethod
from functools import wraps
from threading import Thread, Event, Lock
from typing import Optional, Callable, Any, Mapping, List
from uuid import UUID
import requests
from flask import Flask, request
from .action import ConnectionRefuse, DisconnectionRefuse, TaskRefuse, TaskFail
from .error_code import SlaveErrorCode
from ..base import random_token, ControllableService, get_http_engine_class, split_http_address, success_response, \
failure_response, DblEvent
from ..config import DEFAULT_SLAVE_PORT, DEFAULT_CHANNEL, GLOBAL_HOST, DEFAULT_HEARTBEAT_SPAN, MIN_HEARTBEAT_SPAN
class Slave(ControllableService):
def __init__(
self,
host: Optional[str] = None,
port: Optional[int] = None,
heartbeat_span: Optional[float] = None,
channel: Optional[int] = None
):
# server part
self.__host = host or GLOBAL_HOST
self.__port = port or DEFAULT_SLAVE_PORT
self.__flask_app_value = None
self.__run_app_thread = Thread(target=self.__run_app)
# heartbeat part
self.__heartbeat_span = max(heartbeat_span or DEFAULT_HEARTBEAT_SPAN, MIN_HEARTBEAT_SPAN)
self.__heartbeat_thread = Thread(target=self.__heartbeat)
# task part
self.__has_task = DblEvent()
self.__task_lock = Lock()
self.__task_id = None
self.__task_data = None
self.__task_thread = Thread(target=self.__task)
# self-connection part
self.__self_http_engine = get_http_engine_class(headers={
'Token': lambda: self.__self_token,
# })()('localhost', self.__port, False)
})()(self.__host, self.__port, False)
self.__self_token = random_token()
# master-connection part
self.__channel = channel or DEFAULT_CHANNEL
self.__connected = DblEvent()
self.__master_token = None
self.__master_address = None
self.__master_http_engine = None
# global part
self.__shutdown_event = Event()
self.__lock = Lock()
# master connection
def __register_master(self, token: str, address: str):
self.__master_token = token
self.__master_address = address
self.__master_http_engine = get_http_engine_class(
headers={
'Channel': lambda: str(self.__channel),
'Token': lambda: self.__master_token,
}
)()(*split_http_address(self.__master_address))
def __unregister_master(self):
self.__master_token = None
self.__master_address = None
self.__master_http_engine = None
def __open_master_connection(self, token: str, address: str):
self.__register_master(token, address)
self.__connected.open()
def __close_master_connection(self):
self.__unregister_master()
self.__connected.close()
# server part
def __generate_app(self):
app = Flask(__name__)
# master apis
app.route('/connect', methods=['POST'])(self.__check_master_request(self.__connect, False))
app.route('/disconnect', methods=['DELETE'])(self.__check_master_request(self.__disconnect, True))
app.route('/task/new', methods=['POST'])(self.__check_master_request(self.__new_task, True))
# For aggregator only
app.route('/learners', methods=['POST'])(self.__check_master_request(self.__update_learners, True))
# self apis
app.route('/ping', methods=['GET'])(self.__check_self_request(self.__self_ping))
app.route('/shutdown', methods=['DELETE'])(self.__check_self_request(self.__self_shutdown))
return app
def __update_learners(self, token: str, data: List[str]):
return success_response(data=self._process_learners_update(data))
def __flask_app(self) -> Flask:
self.__flask_app_value = self.__flask_app_value or self.__generate_app()
return self.__flask_app_value
def __run_app(self):
while True:
try:
self.__flask_app().run(
host=self.__host,
port=self.__port,
)
except Exception:
print("Failed to run flask app on {}:{}, retrying...".format(self.__host, self.__port), file=sys.stderr)
time.sleep(0.2)
else:
break
# both method checkers
def __check_shutdown(self, func: Callable[[], Any]) -> Callable[[], Any]:
@wraps(func)
def _func():
if self.__shutdown_event.is_set():
return failure_response(
code=SlaveErrorCode.SYSTEM_SHUTTING_DOWN, message='System is already shutting down.'
), 401
else:
return func()
return _func
# server method checkers (master)
def __check_master_request(self,
func: Callable[[str, Mapping[str, Any]], Any],
need_match: bool = True) -> Callable[[], Any]:
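# Checks run outermost-first: reject while shutting down, then verify the channel, then verify the master token before dispatching to func.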
return self.__check_shutdown(self.__check_channel(self.__check_master_token(func, need_match)))
# noinspection DuplicatedCode
def __check_channel(self, func: Callable[[], Any]) -> Callable[[], Any]:
@wraps(func)
def _func():
channel = request.headers.get('Channel', None)
channel = int(channel) if channel else None
if channel is None:
return failure_response(code=SlaveErrorCode.CHANNEL_NOT_FOUND, message='Channel not found.'), 400
elif channel != self.__channel:
return failure_response(
code=SlaveErrorCode.CHANNEL_INVALID, message='Channel does not match this endpoint.'
), 403
else:
return func()
return _func
def __check_master_token(self,
func: Callable[[str, Mapping[str, Any]], Any],
need_match: bool = True) -> Callable[[], Any]:
@wraps(func)
def _func():
master_token = request.headers.get('Token', None)
if master_token is None:
return failure_response(
code=SlaveErrorCode.MASTER_TOKEN_NOT_FOUND, message='Master token not found.'
), 400
elif need_match and (master_token != self.__master_token):
return failure_response(
code=SlaveErrorCode.MASTER_TOKEN_INVALID, message='Master token does not match this endpoint.'
), 403
else:
return func(master_token, json.loads(request.data.decode()))
return _func
# server method checkers (self)
# noinspection DuplicatedCode
def __check_self_request(self, func: Callable[[], Any]) -> Callable[[], Any]:
return self.__check_shutdown(self.__check_slave_token(func))
def __check_slave_token(self, func: Callable[[], Any]) -> Callable[[], Any]:
@wraps(func)
def _func():
slave_token = request.headers.get('Token', None)
if slave_token is None:
return failure_response(code=SlaveErrorCode.SELF_TOKEN_NOT_FOUND, message='Slave token not found.'), 400
elif slave_token != self.__self_token:
return failure_response(
code=SlaveErrorCode.SELF_TOKEN_INVALID, message='Slave token does not match this endpoint.'
), 403
else:
return func()
return _func
# server methods (self)
# noinspection PyMethodMayBeStatic
def __self_ping(self):
return success_response(message='PONG!')
def __self_shutdown(self):
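# Relies on the shutdown hook that the werkzeug development server exposes through the WSGI environ; it raises if the app runs under a different server.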
_shutdown_func = request.environ.get('werkzeug.server.shutdown')
if _shutdown_func is None:
raise RuntimeError('Not running with the Werkzeug Server')
self.__shutdown_event.set()
_shutdown_func()
return success_response(message='Shutdown request received; this server will shut down shortly.')
# server methods (master)
# noinspection PyUnusedLocal
def __connect(self, token: str, data: Mapping[str, Any]):
if self.__connected.is_open():
return failure_response(
code=SlaveErrorCode.SLAVE_ALREADY_CONNECTED, message='This slave is already connected.'
), 400
else:
_master_info, _connection_data = data['master'], data['data']
try:
self._before_connection(_connection_data)
except ConnectionRefuse as err:
return err.get_response()
else:
self.__open_master_connection(token, _master_info['address'])
return success_response(message='Connect success.')
# noinspection PyUnusedLocal
def __new_task(self, token: str, data: Mapping[str, Any]):
with self.__task_lock:
if self.__has_task.is_open():
return failure_response(code=SlaveErrorCode.TASK_ALREADY_EXIST, message='This slave already has a task.'), 400
else:
_task_info, _task_data = data['task'], data['data']
_task_id = _task_info['id']
try:
self._before_task(_task_data)
except TaskRefuse as err:
return err.get_response()
else:
self.__task_id = UUID(_task_id)
self.__task_data = _task_data
self.__has_task.open()
return success_response(message='Task received!')
# noinspection PyUnusedLocal
def __disconnect(self, token: str, data: Mapping[str, Any]):
if self.__connected.is_close():
return failure_response(
code=SlaveErrorCode.SLAVE_NOT_CONNECTED, message='This slave is not connected yet.'
), 400
else:
_disconnection_data = data['data']
try:
self._before_disconnection(_disconnection_data)
except DisconnectionRefuse as err:
return err.get_response()
else:
self.__close_master_connection()
return success_response(message='Disconnect success.')
# heartbeat part
def __heartbeat(self):
_last_time = time.time()
while not self.__shutdown_event.is_set():
if self.__connected.is_open():
try:
self.__master_heartbeat()
except requests.exceptions.RequestException as err:
self._lost_connection(self.__master_address, err)
self.__close_master_connection()
traceback.print_exception(*sys.exc_info(), file=sys.stderr)
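# Advance the schedule by a fixed span and sleep only for the remainder, so heartbeats keep a steady cadence instead of drifting by the request latency.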
_last_time += self.__heartbeat_span
time.sleep(max(_last_time - time.time(), 0))
# task part
def __task(self):
while not self.__shutdown_event.is_set():
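# Wait with a short timeout so the loop can re-check the shutdown flag even when no task arrives.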
self.__has_task.wait_for_open(timeout=1.0)
if self.__has_task.is_open():
# noinspection PyBroadException
try:
result = self._process_task(self.__task_data)
except TaskFail as fail:
self.__has_task.close()
self.__master_task_fail(fail.result)
except Exception:
self.__has_task.close()
traceback.print_exception(*sys.exc_info(), file=sys.stderr)
else:
self.__has_task.close()
self.__master_task_complete(result)
# self request operations
def __self_request(self, method: Optional[str] = 'GET', path: Optional[str] = None) -> requests.Response:
return self.__self_http_engine.request(method, path)
def __ping_once(self):
return self.__self_request('GET', '/ping')
def __ping_until_started(self):
while True:
try:
self.__ping_once()
except (requests.exceptions.BaseHTTPError, requests.exceptions.RequestException):
time.sleep(0.2)
else:
break
def __shutdown(self):
self.__self_request('DELETE', '/shutdown')
# master request operations
def __master_request(
self,
method: Optional[str] = 'GET',
path: Optional[str] = None,
data: Optional[Mapping[str, Any]] = None
) -> requests.Response:
return self.__master_http_engine.request(method, path, data)
def __master_heartbeat(self):
return self.__master_request('GET', '/slave/heartbeat')
def __master_task_complete(self, result: Mapping[str, Any]):
return self.__master_request(
'PUT', '/slave/task/complete', data={
'task': {
'id': str(self.__task_id)
},
'result': result or {},
}
)
def __master_task_fail(self, result: Mapping[str, Any]):
return self.__master_request(
'PUT', '/slave/task/fail', data={
'task': {
'id': str(self.__task_id)
},
'result': result or {},
}
)
# public methods
def ping(self) -> bool:
with self.__lock:
try:
self.__ping_once()
except (requests.exceptions.BaseHTTPError, requests.exceptions.RequestException):
return False
else:
return True
def start(self):
with self.__lock:
self.__task_thread.start()
self.__heartbeat_thread.start()
self.__run_app_thread.start()
self.__ping_until_started()
def shutdown(self):
with self.__lock:
self.__shutdown()
def join(self):
with self.__lock:
self.__run_app_thread.join()
self.__heartbeat_thread.join()
self.__task_thread.join()
# inherit method
def _before_connection(self, data: Mapping[str, Any]):
pass
def _before_disconnection(self, data: Mapping[str, Any]):
pass
def _before_task(self, data: Mapping[str, Any]):
pass
def _lost_connection(self, master_address: str, err: requests.exceptions.RequestException):
pass
# For aggregator only
def _process_learners_update(self, replicas: List[str]):
raise NotImplementedError
@abstractmethod
def _process_task(self, task: Mapping[str, Any]):
raise NotImplementedError
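# --------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original commit): a concrete
# slave only has to implement _process_task(); _before_connection() and
# _before_task() may raise ConnectionRefuse / TaskRefuse to reject requests,
# and _process_task() may raise TaskFail to report a failed task. The class
# name EchoSlave and the port/channel values below are hypothetical.
#
#     class EchoSlave(Slave):
#         def _process_task(self, task):
#             # Echo the task payload back to the master as the result.
#             return {'echo': dict(task)}
#
#     # ControllableService supplies the context-manager protocol
#     # (start() on enter, shutdown() and join() on exit, see test_common):
#     with EchoSlave(port=7236, channel=233):
#         ...  # the slave now serves /connect, /task/new and /disconnect
# --------------------------------------------------------------------------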
from .base import *
from .config import *
from .interaction import *
from .test_app import TestInteractionBaseApp, TestInteractionBaseResponsibleException
from .test_common import TestInteractionBaseCommon, TestInteractionBaseControllableService
from .test_network import TestInteractionBaseHttpEngine, TestInteractionBaseNetwork
from .test_threading import TestInteractionBaseThreading
import json
import pytest
from flask import Flask
from ...base import success_response, failure_response, get_values_from_response, ResponsibleException, responsible
@pytest.mark.unittest
class TestInteractionBaseApp:
def test_success_response(self):
app = Flask('_test_success_response')
@app.route('/success', methods=['GET'])
def success_method():
return success_response(
data={
'a': 1,
'b': 2,
'sum': 3,
},
message='This is success message.',
)
client = app.test_client()
response = client.get('/success')
assert response.status_code == 200
assert json.loads(response.data.decode()) == {
'success': True,
'code': 0,
'data': {
'a': 1,
'b': 2,
'sum': 3,
},
'message': 'This is success message.',
}
# noinspection DuplicatedCode
def test_failure_response(self):
app = Flask('_test_failure_response')
@app.route('/fail', methods=['GET'])
def fail_method():
return failure_response(
code=233,
message='This is failure message.',
data={
'a': 2,
'b': 3,
'sum': 5,
},
), 404
client = app.test_client()
response = client.get('/fail')
assert response.status_code == 404
assert json.loads(response.data.decode()) == {
'success': False,
'code': 233,
'data': {
'a': 2,
'b': 3,
'sum': 5,
},
'message': 'This is failure message.',
}
def test_get_values_from_response(self):
app = Flask('_test_get_values_from_response')
@app.route('/success', methods=['GET'])
def success_method():
return success_response(
data={
'a': 1,
'b': 2,
'sum': 3,
},
message='This is success message.',
)
@app.route('/fail', methods=['GET'])
def fail_method():
return failure_response(
code=233,
message='This is failure message.',
data={
'a': 2,
'b': 3,
'sum': 5,
},
), 404
client = app.test_client()
response = client.get('/success')
assert response.status_code == 200
assert get_values_from_response(response) == (
200,
True,
0,
'This is success message.',
{
'a': 1,
'b': 2,
'sum': 3,
},
)
response = client.get('/fail')
assert response.status_code == 404
assert get_values_from_response(response) == (
404,
False,
233,
'This is failure message.',
{
'a': 2,
'b': 3,
'sum': 5,
},
)
@pytest.mark.unittest
class TestInteractionBaseResponsibleException:
# noinspection DuplicatedCode
def test_it(self):
class NotFound(ResponsibleException):
def __init__(self):
ResponsibleException.__init__(
self=self,
status_code=404,
code=233,
message='This is failure message.',
data={
'a': 2,
'b': 3,
'sum': 5,
}
)
class AccessDenied(ResponsibleException):
def __init__(self):
ResponsibleException.__init__(
self=self,
status_code=403,
code=322,
message='This is another failure message.',
data={
'a': 2,
'b': 3,
'sum': 7,
}
)
app = Flask('_test_failure_response')
@app.route('/fail', methods=['GET'])
@responsible(classes=(NotFound, ))
def fail_method():
raise NotFound
@app.route('/403', methods=['GET'])
@responsible()
def denied_method():
raise AccessDenied
@app.route('/success', methods=['GET'])
@responsible()
def success_method():
return success_response(
data={
'a': 1,
'b': 2,
'sum': 3,
},
message='This is success message.',
)
client = app.test_client()
response = client.get('/fail')
assert response.status_code == 404
assert json.loads(response.data.decode()) == {
'success': False,
'code': 233,
'data': {
'a': 2,
'b': 3,
'sum': 5,
},
'message': 'This is failure message.',
}
response = client.get('/403')
assert response.status_code == 403
assert json.loads(response.data.decode()) == {
'success': False,
'code': 322,
'data': {
'a': 2,
'b': 3,
'sum': 7,
},
'message': 'This is another failure message.',
}
response = client.get('/success')
assert response.status_code == 200
assert json.loads(response.data.decode()) == {
'success': True,
'code': 0,
'data': {
'a': 1,
'b': 2,
'sum': 3,
},
'message': 'This is success message.',
}
import string
import time
from typing import Any, Callable
import pytest
from ...base import random_token, translate_dict_func, default_func, ControllableService
@pytest.mark.unittest
class TestInteractionBaseCommon:
def test_random_token(self):
assert len(random_token()) == 64
assert len(random_token(32)) == 32
assert set(random_token()) - set(string.hexdigits) == set()
def test_translate_dict_func(self):
assert translate_dict_func({
'a': lambda: 2,
'b': lambda: 3,
'sum': lambda: 5,
})() == {
'a': 2,
'b': 3,
'sum': 5
}
assert translate_dict_func(
{
'a': lambda ax, bx: 2 + ax,
'b': lambda ax, bx: 3 + bx,
'sum': lambda ax, bx: 5 + ax + bx,
}
)(4, 5) == {
'a': 6,
'b': 8,
'sum': 14
}
def test_default_func(self):
def _calculate(a: int, b: int, callback: Callable[..., Any] = None):
return default_func(233)(callback)(a, b)
assert _calculate(1, 2) == 233
assert _calculate(1, 2, lambda a, b: a + b) == 3
assert _calculate(1, 2, lambda a, b: a * b) == 2
@pytest.mark.unittest
class TestInteractionBaseControllableService:
def test_it(self):
_start, _shutdown, _finished = False, False, False
class _Service(ControllableService):
def start(self):
nonlocal _start
_start = True
def shutdown(self):
nonlocal _shutdown
_shutdown = True
def join(self):
time.sleep(1.0)
nonlocal _finished
_finished = True
assert (_start, _shutdown, _finished) == (False, False, False)
with _Service():
assert (_start, _shutdown, _finished) == (True, False, False)
assert (_start, _shutdown, _finished) == (True, True, True)
import time
from threading import Thread
import pytest
from ...base import DblEvent
@pytest.mark.unittest
class TestInteractionBaseThreading:
# noinspection DuplicatedCode
@pytest.mark.execution_timeout(5.0, method='thread')
def test_dbl_event_open(self):
event = DblEvent()
assert event.is_close()
assert not event.is_open()
# Opening test
_time_1, _time_2 = 0.0, 0.0
def _run_1_wait_for_open():
nonlocal _time_1
event.wait_for_open()
_time_1 = time.time()
def _run_2_wait_for_open():
nonlocal _time_2
event.wait_for_open()
_time_2 = time.time()
_thread_1 = Thread(target=_run_1_wait_for_open)
_thread_2 = Thread(target=_run_2_wait_for_open)
_thread_1.start()
_thread_2.start()
time.sleep(0.2)
assert event.is_close()
assert not event.is_open()
assert _time_1 == 0.0
assert _time_2 == 0.0
time.sleep(0.8)
event.open()
_thread_1.join()
_thread_2.join()
assert abs(time.time() - _time_1) < 0.3
assert abs(time.time() - _time_2) < 0.3
assert not event.is_close()
assert event.is_open()
# Closing test
_time_1, _time_2 = 0.0, 0.0
def _run_1_wait_for_close():
nonlocal _time_1
event.wait_for_close()
_time_1 = time.time()
def _run_2_wait_for_close():
nonlocal _time_2
event.wait_for_close()
_time_2 = time.time()
_thread_1 = Thread(target=_run_1_wait_for_close)
_thread_2 = Thread(target=_run_2_wait_for_close)
_thread_1.start()
_thread_2.start()
time.sleep(0.2)
assert not event.is_close()
assert event.is_open()
assert _time_1 == 0.0
assert _time_2 == 0.0
time.sleep(0.8)
event.close()
_thread_1.join()
_thread_2.join()
assert abs(time.time() - _time_1) < 0.3
assert abs(time.time() - _time_2) < 0.3
assert event.is_close()
assert not event.is_open()
# noinspection DuplicatedCode
@pytest.mark.execution_timeout(5.0, method='thread')
def test_dbl_event_close(self):
event = DblEvent(True)
assert not event.is_close()
assert event.is_open()
# Closing test
_time_1, _time_2 = 0.0, 0.0
def _run_1_wait_for_close():
nonlocal _time_1
event.wait_for_close()
_time_1 = time.time()
def _run_2_wait_for_close():
nonlocal _time_2
event.wait_for_close()
_time_2 = time.time()
_thread_1 = Thread(target=_run_1_wait_for_close)
_thread_2 = Thread(target=_run_2_wait_for_close)
_thread_1.start()
_thread_2.start()
time.sleep(0.2)
assert not event.is_close()
assert event.is_open()
assert _time_1 == 0.0
assert _time_2 == 0.0
time.sleep(0.8)
event.close()
_thread_1.join()
_thread_2.join()
assert abs(time.time() - _time_1) < 0.3
assert abs(time.time() - _time_2) < 0.3
assert event.is_close()
assert not event.is_open()
from .test_base import TestInteractionConfig
import pytest
from ...config import GLOBAL_HOST, LOCAL_HOST
@pytest.mark.unittest
class TestInteractionConfig:
def test_base_host(self):
assert GLOBAL_HOST == '0.0.0.0'
assert LOCAL_HOST == '127.0.0.1'
from .test_errors import TestInteractionErrors
from .test_simple import TestInteractionSimple
from .random import random_port, random_channel
from .stream import silence, silence_function
import random
from typing import Iterable
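# Test helpers: random_port() picks a port in [10000, 20000) and random_channel() picks a channel id in [1000, 2**31 - 1], both avoiding any ids passed in excludes.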
def random_port(excludes: Iterable[int] = None) -> int:
return random.choice(list(set(range(10000, 20000)) - set(excludes or [])))
def random_channel(excludes: Iterable[int] = None) -> int:
excludes = set(excludes or [])
while True:
_channel = random.randint(1000, (1 << 31) - 1)
if _channel not in excludes:
return _channel
from .learner_aggregator import MockLearnerAggregator