tf_k8s 2.1 KB
Newer Older
G
gongweibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
#!/bin/bash
#######################################
# Report a trainer's exit status, record a human-readable reason for the pod,
# and terminate this script with that same status.
# Globals:
#   TERMINATION_LOG (read, optional) - file that receives the reason text;
#     defaults to /dev/termination-log.
# Arguments:
#   $1 - exit status of the trainer process; treated as 0 when omitted.
# Outputs:
#   Progress lines on stdout. For statuses 1, 134, 136 and 139 a one-line
#   reason is written to the termination log; other statuses write nothing.
# Returns:
#   Never returns: exits the shell with the given status.
#######################################
check_trainer_ret() {
  local ret="${1:-0}"
  local log_path="${TERMINATION_LOG:-/dev/termination-log}"
  local reason=""

  # The printf builtin writes each line immediately, so no stdbuf is needed.
  printf '%s\n' "job returned ${ret}...setting pod return message..."
  printf '%s\n' "==============================="

  # By shell convention 128+N means "killed by signal N":
  # 134 = SIGABRT, 136 = SIGFPE, 139 = SIGSEGV (Linux signal numbers).
  case "${ret}" in
    136) reason="Error Arithmetic Operation(Floating Point Exception)" ;;
    139) reason="Segmentation Fault" ;;
    1)   reason="General Error" ;;
    134) reason="Program Abort" ;;
  esac

  if [[ -n "${reason}" ]]; then
    printf '%s\n' "${reason}" > "${log_path}"
  fi

  printf '%s\n' "termination log wroted..."
  exit "${ret}"
}

# Endpoint lists filled in by wait_running_pods and later spliced into the
# --ps_hosts / --worker_hosts flags of the training command.
g_pservers=""
g_trainers=""

#######################################
# Run the k8s_tools.py "wait_pods_running" step for the job's parameter-server
# and trainer pods, then store what "fetch_endpoints" prints for each group.
# (Presumably the helper blocks until the requested number of pods is
# running and prints their addresses; confirm against k8s_tools.py.)
# Globals:
#   JOB_NAME, PSERVERS_NUM, TRAINERS_NUM, PORT (read)
#   g_pservers, g_trainers (written)
# Arguments:
#   None
# Returns:
#   0 if every helper call succeeded, otherwise the status of the last call
#   that failed. All steps are attempted even after a failure.
#######################################
wait_running_pods() {
  local pserver_label="tf-job-pserver=${JOB_NAME}"
  local trainer_label="tf-job-trainer=${JOB_NAME}"
  local rc=0

  stdbuf -oL python /root/k8s_tools.py wait_pods_running \
    "${pserver_label}" "${PSERVERS_NUM}" || rc=$?
  stdbuf -oL python /root/k8s_tools.py wait_pods_running \
    "${trainer_label}" "${TRAINERS_NUM}" || rc=$?

  g_pservers=$(python /root/k8s_tools.py fetch_endpoints \
    "${pserver_label}" "${PORT}") || rc=$?
  g_trainers=$(python /root/k8s_tools.py fetch_endpoints \
    "${trainer_label}" "${PORT}") || rc=$?

  return "${rc}"
}

#######################################
# Start the parameter-server process for this pod.
# Calls wait_running_pods, asks k8s_tools.py "fetch_id" for this pod's task
# index (presumably its position among pods carrying the pserver label;
# confirm against k8s_tools.py), then runs ${ENTRY} from ${TRAINER_PACKAGE}.
# Globals:
#   JOB_NAME, ENTRY, TF_JOB_NAME, TRAINER_PACKAGE (read)
#   g_pservers, g_trainers (read after wait_running_pods has set them)
# Arguments:
#   None
# Returns:
#   Exit status of the launched command.
#######################################
start_tf_pserver() {
  local label pserver_id cmd

  wait_running_pods

  label="tf-job-pserver=${JOB_NAME}"
  pserver_id=$(python /root/k8s_tools.py fetch_id "${label}")

  cmd="${ENTRY} --ps_hosts=${g_pservers} --worker_hosts=${g_trainers} \
  --job_name=${TF_JOB_NAME} --task_index=${pserver_id}"

  # Handed to sh as a single string on purpose: ENTRY may itself be a
  # multi-word command line that the inner shell has to split.
  stdbuf -oL sh -c "cd ${TRAINER_PACKAGE} && ${cmd}"
}

#######################################
# Start the trainer (worker) process for this pod and report its result.
# Calls wait_running_pods, asks k8s_tools.py "fetch_id" for this pod's task
# index (presumably its position among pods carrying the trainer label;
# confirm against k8s_tools.py), runs ${ENTRY} from ${TRAINER_PACKAGE}, and
# passes the command's exit status to check_trainer_ret.
# Globals:
#   JOB_NAME, ENTRY, TF_JOB_NAME, TRAINER_PACKAGE, BATCH_SIZE (read)
#   g_pservers, g_trainers (read after wait_running_pods has set them)
# Arguments:
#   None
# Returns:
#   Whatever check_trainer_ret does with the trainer's exit status (in this
#   file it exits the script with that status).
#######################################
start_tf_trainer() {
  local label trainer_id cmd

  wait_running_pods

  label="tf-job-trainer=${JOB_NAME}"
  trainer_id=$(python /root/k8s_tools.py fetch_id "${label}")

  cmd="${ENTRY} --ps_hosts=${g_pservers} --worker_hosts=${g_trainers} \
  --job_name=${TF_JOB_NAME} --task_index=${trainer_id} --batch_size=${BATCH_SIZE}"

  # Handed to sh as a single string on purpose: ENTRY may itself be a
  # multi-word command line that the inner shell has to split.
  stdbuf -oL sh -c "cd ${TRAINER_PACKAGE} && ${cmd}"
  # Must come directly after the command so $? is still the trainer's status.
  check_trainer_ret "$?"
}

# Launch the role selected by TF_JOB_NAME: exactly "worker" starts a trainer;
# any other value (including unset or empty) starts a parameter server.
start_tf() {
    case "${TF_JOB_NAME}" in
        worker)
            start_tf_trainer
            ;;
        *)
            start_tf_pserver
            ;;
    esac
}

# Print the command-line help text (two lines) to stdout.
usage() {
    printf '%s\n' \
        "usage: tf_k8s [<args>]:" \
        "  start_tf         Start tensorflow jobs"
}

# Command-line dispatch: "start_tf" launches the job; "--help", a missing
# argument, and any unrecognised command all print the usage text.
if [[ "$1" == "start_tf" ]]; then
    start_tf
else
    usage
fi