#!/bin/bash
#
# run_node_classification.sh
#
# Driver script for a node-classification job:
#   1. flattens a YAML config file into shell variables,
#   2. runs ./preprocessing/dump_graph.py with values taken from that config,
#   3. launches training + inference, choosing the launcher from the
#      config's `learner_type` value ("cpu" or "gpu").
#
# Usage: run_node_classification.sh [CONFIG_YAML]
#   CONFIG_YAML  path to the YAML config; defaults to ./config.yaml

# Trace every command as it is executed.
set -x

# First positional argument selects the config file; fall back to ./config.yaml.
config=${1:-"./config.yaml"}

# Drop any proxy settings inherited from the caller's environment.
unset http_proxy https_proxy

#######################################
# Flatten a simple, indentation-nested YAML file into shell assignments.
#
# Each `key: value` line is printed as `<prefix><parents_>key="value"`, where
# the parent keys (joined with `_`) are derived from the line's indentation.
# The nesting depth is computed as leading-whitespace length / 2, so the
# input is expected to use two-character indentation steps. Keys without a
# value (e.g. `section:`) only contribute to the names of their children.
#
# Arguments:
#   $1 - path to the YAML file
#   $2 - optional prefix prepended to every generated variable name
# Outputs:
#   One `name="value"` line per valued key on stdout (suitable for `eval`).
#######################################
function parse_yaml {
   local yaml_file=$1
   local prefix=${2:-}
   local s='[[:space:]]*' w='[a-zA-Z0-9_]*'
   # ASCII FS (octal 034) is used as a field separator that will not occur in
   # normal config text. Declared separately so the assignment's status is
   # not masked by `local`.
   local fs
   fs=$(printf '\034')

   # sed expressions, in order:
   #   1. drop a colon that directly follows the leading whitespace;
   #   2. `key: "value"` / `key: 'value'` -> indent FS key FS value (quotes removed);
   #   3. `key: value` (unquoted)          -> indent FS key FS value.
   # The file name is quoted so paths containing spaces or glob characters work.
   sed -ne "s|^\($s\):|\1|" \
        -e "s|^\($s\)\($w\)$s:$s[\"']\(.*\)[\"']$s\$|\1$fs\2$fs\3|p" \
        -e "s|^\($s\)\($w\)$s:$s\(.*\)$s\$|\1$fs\2$fs\3|p" "$yaml_file" |
   awk -F"$fs" -v prefix="$prefix" '{
      indent = length($1)/2;
      vname[indent] = $2;
      for (i in vname) {if (i > indent) {delete vname[i]}}
      if (length($3) > 0) {
         vn=""; for (i=0; i<indent; i++) {vn=(vn)(vname[i])("_")}
         printf("%s%s%s=\"%s\"\n", prefix, vn, $2, $3);
      }
   }'
}

#######################################
# Run a local parameter-server style job: start the pserver process(es) in
# the background, then run training followed by inference in the foreground
# for each trainer.
#
# Globals:
#   config (read)  - path passed to the python scripts via --conf
#   PADDLE_TRAINERS_NUM, PADDLE_PSERVERS_NUM, PADDLE_PORT, PADDLE_PSERVERS,
#   BASE (exported) - set here for the launched python processes
# Files:
#   ${BASE}/pserver.<i>.log - combined stdout/stderr of each pserver
#   ./job_id                - one PID per line for every background pserver
# Notes:
#   The background pservers are neither waited for nor killed by this
#   function; their PIDs are left in ./job_id.
#######################################
transpiler_local_train(){
    export PADDLE_TRAINERS_NUM=1
    export PADDLE_PSERVERS_NUM=1
    export PADDLE_PORT=6206
    export PADDLE_PSERVERS="127.0.0.1"
    export BASE="./local_dir"

    # Show which python interpreter will be used.
    command -v python

    # Start from an empty log directory; `:?` guards against an empty path.
    rm -rf -- "${BASE:?}"
    mkdir -p -- "${BASE}"
    # -f: a missing job_id from a previous run is not an error.
    rm -f job_id

    local i j
    for ((i = 0; i < PADDLE_PSERVERS_NUM; i++)); do
        echo "start ps server: ${i}"
        TRAINING_ROLE="PSERVER" PADDLE_TRAINER_ID=${i} python ./train.py --conf "$config" \
            &> "${BASE}/pserver.${i}.log" &
        # Record the PID of the pserver that was just backgrounded.
        echo $! >> job_id
    done

    # Pause before starting the trainers so the pservers have time to start.
    sleep 3

    for ((j = 0; j < PADDLE_TRAINERS_NUM; j++)); do
        echo "start ps work: ${j}"
        TRAINING_ROLE="TRAINER" PADDLE_TRAINER_ID=${j} python ./train.py --conf "$config"
        TRAINING_ROLE="TRAINER" PADDLE_TRAINER_ID=${j} python ./infer.py --conf "$config"
    done
}

#######################################
# Run training and then inference through `python -m paddle.distributed.launch`.
#
# Globals:
#   config (read) - path passed to both scripts via --conf
# Notes:
#   Inference is started regardless of the training command's exit status.
#######################################
collective_local_train(){
    # Show which python interpreter will be used.
    command -v python
    python -m paddle.distributed.launch node_classification_train.py --conf "$config"
    python -m paddle.distributed.launch node_classification_infer.py --conf "$config"
}

# Load every valued key of the config file as a shell variable.
# NOTE(review): eval executes whatever text parse_yaml emits, so the config
# file must be trusted. Quoting the substitution keeps each generated
# assignment on its own line and prevents word-splitting/globbing of values.
eval "$(parse_yaml "$config")"

# Preprocessing step; every argument comes from the config loaded above.
# Arguments are quoted so values containing spaces or glob characters are
# passed through unchanged.
python ./preprocessing/dump_graph.py -i "$train_data" -g "$graph_data" -o "$graph_work_path" \
    --encoding "$encoding" -l "$max_seqlen" --vocab_file "$ernie_vocab_file" --task "$task"

# Choose the launcher from the config's learner_type value.
case "$learner_type" in
    cpu)
        transpiler_local_train
        ;;
    gpu)
        collective_local_train
        ;;
    *)
        # Previously an unrecognised value was skipped silently; report it.
        printf 'unknown learner_type "%s" (expected "cpu" or "gpu"); nothing launched\n' \
            "$learner_type" >&2
        ;;
esac