未验证 提交 32a142aa 编写于 作者: X Xuefeng Xu 提交者: GitHub

add fetch task status in python sdk (#501)

* add fetch task status in python sdk

* fix bug when getting async status

* exit if task fail or nonexist
上级 4727be1f
......@@ -13,10 +13,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import uuid
import time
import grpc
from src.primihub.protos import common_pb2, worker_pb2, worker_pb2_grpc # noqa
from .src.primihub.protos import common_pb2, worker_pb2, worker_pb2_grpc # noqa
class WorkerClient:
......@@ -39,48 +39,32 @@ class WorkerClient:
party_name: str = None,
party_datasets: dict = None,
party_access_info: dict = None) -> None:
"""Constructor
"""
self.node = node
self.cert = cert
self.channel = grpc.insecure_channel(node)
self.request_data = None
self.stub = worker_pb2_grpc.VMNodeStub(self.channel)
self.task_type = task_type
self.name = task_name
self.language = language
self.params = params
self.code = code
self.node_map = node_map
self.input_datasets = input_datasets
self.task_info = task_info
self.party_name = party_name
self.party_datasets = party_datasets
self.party_access_info = party_access_info
def set_request_map(self):
request_map = {
"type": self.task_type,
"name": self.name,
"language": self.language,
"params": self.params,
"code": self.code,
"node_map": self.node_map,
"input_datasets": self.input_datasets,
"task_info": self.task_info,
"party_name": self.party_name,
"party_datasets": self.party_datasets,
"party_access_info": self.party_access_info
self.stub = worker_pb2_grpc.VMNodeStub(self.channel)
self.task = {
"type": task_type,
"name": task_name,
"language": language,
"params": params,
"code": code,
"node_map": node_map,
"input_datasets": input_datasets,
"task_info": task_info,
"party_name": party_name,
"party_datasets": party_datasets,
"party_access_info": party_access_info
}
self.request_map = request_map
@staticmethod
def push_task_request(intended_worker_id=b'1',
def push_task_request(self,
intended_worker_id=None,
task=None,
sequence_number=11,
client_processed_up_to=22,
submit_client_id=b""):
sequence_number=None,
client_processed_up_to=None,
submit_client_id=None):
request_data = {
"intended_worker_id": intended_worker_id,
"task": task,
......@@ -88,26 +72,71 @@ class WorkerClient:
"client_processed_up_to": client_processed_up_to,
"submit_client_id": submit_client_id
}
print(
f"########################The request_data is {request_data}##################"
)
request = worker_pb2.PushTaskRequest(**request_data)
return request
def submit_task(
self,
request: worker_pb2.PushTaskRequest) -> worker_pb2.PushTaskReply:
def submit_task(self):
"""gRPC submit task
:returns: gRPC reply
:rtype: :obj:`worker_pb2.PushTaskReply`
"""
# print(type(request_data), request_data)
self.set_request_map()
request = WorkerClient.push_task_request(task=self.request_map)
PushTaskRequest = self.push_task_request(task=self.task)
print(PushTaskRequest)
print(20*'-')
with self.channel:
reply = self.stub.SubmitTask(request)
print("return code: %s, job id: %s" %
(reply.ret_code, reply.task_info.job_id)) # noqa
return reply
PushTaskReply = self.stub.SubmitTask(PushTaskRequest)
start_time = time.time()
ret_code_map = {0: 'success', 1: 'doing', 2: 'error'}
print('ret_code:', ret_code_map[PushTaskReply.ret_code])
print('task_info:', PushTaskReply.task_info)
print('party_count:', PushTaskReply.party_count)
print('task_server:', PushTaskReply.task_server)
print(20*'-')
task_info = PushTaskReply.task_info
status_map = {
0: 'RUNNING',
1: 'SUCCESS',
2: 'FAIL',
3: 'NONEXIST',
4: "FINISHED"
}
party_status = {}
is_fail = False
while True:
time.sleep(1)
TaskStatusReply = self.stub.FetchTaskStatus(task_info)
for task_status in TaskStatusReply.task_status:
party = task_status.party
status = task_status.status
if status == worker_pb2.TaskStatus.StatusCode.FAIL or \
status == worker_pb2.TaskStatus.StatusCode.NONEXIST:
is_fail = True
if is_fail:
break
if party:
print('party:', party)
print('status:', status_map[status])
print(20*'-')
if status != worker_pb2.TaskStatus.StatusCode.RUNNING:
party_status[party] = status_map[status]
if is_fail or len(party_status) == PushTaskReply.party_count:
break
end_time = time.time()
print(f'time spend: {end_time - start_time:.3f} s')
if is_fail:
print(f"fail: {TaskStatusReply}")
else:
print(f'status: {party_status}')
......@@ -7,9 +7,11 @@ import uuid
class Client:
def __init__(self, json_file, var_type=common_pb2.VarType.STRING, is_array=False):
# json_file contains three components:
# json_file: party_info, component_params
self.party_info = json_file['party_info']
self.component_params = json_file['component_params']
# component_params: common_params, role_params
self.common_params = self.component_params['common_params']
self.role_params = self.component_params['role_params']
self.var_type = var_type
self.is_array = is_array
......@@ -27,7 +29,7 @@ class Client:
party_datasets = {}
for party_name, role_param in self.role_params.items():
Dataset = common_pb2.Dataset()
Dataset.data['data_set'] = role_param.get('data_set','')
Dataset.data['data_set'] = role_param.get('data_set','{}')
party_datasets[party_name] = Dataset
# construct 'task_info'
......@@ -46,10 +48,10 @@ class Client:
Node.use_tls = party_info['use_tls']
party_access_info[party_name] = Node
self.current_worker = WorkerClient(
self.worker = WorkerClient(
node=self.party_info['task_manager'],
cert=None,
task_name=self.component_params['common_params'].get('task_name',''),
task_name=self.common_params.get('task_name',''),
language=common_pb2.Language.PYTHON,
params=params,
task_info=task_info,
......@@ -58,8 +60,4 @@ class Client:
def submit(self):
self.prepare_for_worker()
reply = self.current_worker.submit_task(request=None)
def get_status(self, task_id):
#Todo
pass
\ No newline at end of file
self.worker.submit_task()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册