Commit 2839bccc authored by S suweiyue

1. remove pad; 2. CPU mode
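The core change removes padded positions from the transformer computation: in the encoder.py hunks below, to_2d/to_3d switch from no-ops to L.gather_nd / L.scatter_nd over the non-masked token positions, so each block only computes on real tokens. A minimal NumPy sketch of that round trip, purely illustrative (the helper below is not from the repo):

```python
# Illustrative sketch of the "pad remove" trick: gather only non-padded token
# rows into a 2D tensor before the heavy sub-layers, then scatter results back
# into the padded [batch, seqlen, hidden] layout (padding comes back as zeros).
import numpy as np

batch, seqlen, hidden = 2, 4, 8
enc_input = np.random.randn(batch, seqlen, hidden).astype("float32")
# input_mask: 1 for real tokens, 0 for padding
input_mask = np.array([[1, 1, 1, 0],
                       [1, 1, 0, 0]], dtype="int64")

pad_idx = np.argwhere(input_mask == 1)           # [num_valid, 2], like build_pad_idx

def to_2d(t_3d):
    # analogous to L.gather_nd(t_3d, pad_idx)
    return t_3d[pad_idx[:, 0], pad_idx[:, 1]]    # [num_valid, hidden]

def to_3d(t_2d):
    # analogous to L.scatter_nd(pad_idx, t_2d, shape=[batch, seqlen, hidden])
    t_3d = np.zeros((batch, seqlen, hidden), dtype=t_2d.dtype)
    t_3d[pad_idx[:, 0], pad_idx[:, 1]] = t_2d
    return t_3d

flat = to_2d(enc_input)      # run attention/FFN projections on valid tokens only
restored = to_3d(flat)
assert restored.shape == enc_input.shape
```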

Parent 8e6ea314
......@@ -4,9 +4,9 @@
learner_type: "cpu"
optimizer_type: "adam"
lr: 0.00005
batch_size: 2
CPU_NUM: 10
epoch: 20
batch_size: 4
CPU_NUM: 16
epoch: 3
log_per_step: 1
save_per_step: 100
output_path: "./output"
......@@ -31,6 +31,7 @@ final_fc: true
final_l2_norm: true
loss_type: "hinge"
margin: 0.3
neg_type: "random_neg"
# infer config ------
infer_model: "./output/last"
......
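The config keeps loss_type "hinge" with margin 0.3 and adds neg_type "random_neg". A minimal sketch of the standard pairwise hinge formulation over positive and random-negative similarity scores (an assumption about the formulation; the repo's actual loss code may differ):

```python
# Pairwise hinge loss with margin 0.3 (assumed standard formulation).
# pos_score / neg_score are similarity scores for (query, positive) and
# (query, random negative) pairs.
import numpy as np

def hinge_loss(pos_score, neg_score, margin=0.3):
    # penalize whenever the negative score is not at least `margin` below the positive
    return np.mean(np.maximum(0.0, neg_score - pos_score + margin))

pos = np.array([0.9, 0.7, 0.4])
neg = np.array([0.1, 0.6, 0.5])
print(hinge_loss(pos, neg))  # (0.0 + 0.2 + 0.4) / 3 = 0.2
```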
......@@ -183,5 +183,6 @@ if __name__ == "__main__":
parser.add_argument("--conf", type=str, default="./config.yaml")
args = parser.parse_args()
config = edict(yaml.load(open(args.conf), Loader=yaml.FullLoader))
config.loss_type = "hinge"
print(config)
main(config)
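For context, the edict wrapper gives attribute access over the parsed YAML, which is why train.py can simply override config.loss_type after loading. A hypothetical snippet, assuming the easydict package and the config.yaml shown above:

```python
# Hypothetical illustration of the edict (easydict) config object used above:
# YAML keys become attributes, and assignments override the parsed values.
import yaml
from easydict import EasyDict as edict

config = edict(yaml.load(open("./config.yaml"), Loader=yaml.FullLoader))
print(config.batch_size, config.CPU_NUM)   # e.g. 4 16
config.loss_type = "hinge"                  # force hinge loss, as train.py does
print(config.loss_type)
```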
unset http_proxy https_proxy
set -x
mode=${1:-local}
config=${2:-"./config.yaml"}
function parse_yaml {
local prefix=$2
local s='[[:space:]]*' w='[a-zA-Z0-9_]*' fs=$(echo @|tr @ '\034')
sed -ne "s|^\($s\):|\1|" \
-e "s|^\($s\)\($w\)$s:$s[\"']\(.*\)[\"']$s\$|\1$fs\2$fs\3|p" \
-e "s|^\($s\)\($w\)$s:$s\(.*\)$s\$|\1$fs\2$fs\3|p" $1 |
awk -F$fs '{
indent = length($1)/2;
vname[indent] = $2;
for (i in vname) {if (i > indent) {delete vname[i]}}
if (length($3) > 0) {
vn=""; for (i=0; i<indent; i++) {vn=(vn)(vname[i])("_")}
printf("%s%s%s=\"%s\"\n", "'$prefix'",vn, $2, $3);
}
}'
}
eval $(parse_yaml $config)
export CPU_NUM=$CPU_NUM
export FLAGS_rpc_deadline=3000000
export FLAGS_rpc_retry_times=1000
if [[ $async_mode == "True" ]];then
echo "async_mode is True"
else
export FLAGS_communicator_send_queue_size=1
export FLAGS_communicator_min_send_grad_num_before_recv=0
export FLAGS_communicator_max_merge_var_num=1 # important!
export FLAGS_communicator_merge_sparse_grad=0
fi
export FLAGS_communicator_recv_wait_times=5000000
mkdir -p output
python ./train.py --conf $config
if [[ $TRAINING_ROLE == "TRAINER" ]];then
python ./infer.py --conf $config
fi
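The parse_yaml helper above flattens nested YAML keys (joined with underscores) into key="value" assignments that the script evals, e.g. CPU_NUM="16". A rough Python equivalent of that output, for illustration only (not part of the repo):

```python
# Rough Python equivalent of parse_yaml: flatten nested YAML keys with "_" and
# emit key="value" lines that a shell script could eval. Illustrative only.
import yaml

def flatten(d, prefix=""):
    for key, value in d.items():
        if isinstance(value, dict):
            yield from flatten(value, prefix + key + "_")
        else:
            yield '%s%s="%s"' % (prefix, key, value)

with open("./config.yaml") as f:
    conf = yaml.safe_load(f)
for line in flatten(conf):
    print(line)   # e.g. CPU_NUM="16", batch_size="4"
```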
......@@ -26,6 +26,17 @@ from paddle.fluid.incubate.fleet.collective import fleet as cfleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as tfleet
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from tensorboardX import SummaryWriter
from paddle.fluid.transpiler.distribute_transpiler import DistributedMode
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import TrainerRuntimeConfig
# hack it!
base_get_communicator_flags = TrainerRuntimeConfig.get_communicator_flags
def get_communicator_flags(self):
flag_dict = base_get_communicator_flags(self)
flag_dict['communicator_max_merge_var_num'] = str(1)
flag_dict['communicator_send_queue_size'] = str(1)
return flag_dict
TrainerRuntimeConfig.get_communicator_flags = get_communicator_flags
class Learner(object):
......@@ -132,8 +143,6 @@ class TranspilerLearner(Learner):
self.model = model
def optimize(self, loss, optimizer_type, lr):
strategy = DistributeTranspilerConfig()
strategy.sync_mode = False
log.info('learning rate:%f' % lr)
if optimizer_type == "sgd":
optimizer = F.optimizer.SGD(learning_rate=lr)
......@@ -143,7 +152,8 @@ class TranspilerLearner(Learner):
else:
raise ValueError("Unknown Optimizer %s" % optimizer_type)
#create the DistributeTranspiler configure
optimizer = tfleet.distributed_optimizer(optimizer, strategy)
self.strategy = StrategyFactory.create_sync_strategy()
optimizer = tfleet.distributed_optimizer(optimizer, self.strategy)
optimizer.minimize(loss)
def init_and_run_ps_worker(self, ckpt_path):
......
......@@ -36,7 +36,7 @@ transpiler_local_train(){
for((i=0;i<${PADDLE_PSERVERS_NUM};i++))
do
echo "start ps server: ${i}"
TRAINING_ROLE="PSERVER" PADDLE_TRAINER_ID=${i} sh job.sh local $config \
TRAINING_ROLE="PSERVER" PADDLE_TRAINER_ID=${i} python ./train.py --conf $config \
&> $BASE/pserver.$i.log &
echo $! >> job_id
done
......@@ -44,8 +44,8 @@ transpiler_local_train(){
for((j=0;j<${PADDLE_TRAINERS_NUM};j++))
do
echo "start ps work: ${j}"
TRAINING_ROLE="TRAINER" PADDLE_TRAINER_ID=${j} sh job.sh local $config \
echo $! >> job_id
TRAINING_ROLE="TRAINER" PADDLE_TRAINER_ID=${j} python ./train.py --conf $config
TRAINING_ROLE="TRAINER" PADDLE_TRAINER_ID=${j} python ./infer.py --conf $config
done
}
......
......@@ -19,8 +19,6 @@ from contextlib import contextmanager
import paddle.fluid as fluid
import paddle.fluid.layers as L
import paddle.fluid.layers as layers
#import propeller.paddle as propeller
#from propeller import log
#determine this at the beginning
to_3d = lambda a: a # will change later
......@@ -85,7 +83,7 @@ def multi_head_attention(queries,
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
# permute the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
......@@ -262,7 +260,6 @@ def encoder_layer(enc_input,
with the post_process_layer to add residual connection, layer normalization
and dropout.
"""
#L.Print(L.reduce_mean(enc_input), message='1')
attn_output, ctx_multiheads_attn = multi_head_attention(
pre_process_layer(
enc_input,
......@@ -279,7 +276,6 @@ def encoder_layer(enc_input,
attention_dropout,
param_initializer=param_initializer,
name=name + '_multi_head_att')
#L.Print(L.reduce_mean(attn_output), message='1')
attn_output = post_process_layer(
enc_input,
attn_output,
......@@ -287,7 +283,6 @@ def encoder_layer(enc_input,
prepostprocess_dropout,
name=name + '_post_att')
#L.Print(L.reduce_mean(attn_output), message='2')
ffd_output = positionwise_feed_forward(
pre_process_layer(
attn_output,
......@@ -300,14 +295,12 @@ def encoder_layer(enc_input,
hidden_act,
param_initializer=param_initializer,
name=name + '_ffn')
#L.Print(L.reduce_mean(ffd_output), message='3')
ret = post_process_layer(
attn_output,
ffd_output,
postprocess_cmd,
prepostprocess_dropout,
name=name + '_post_ffn')
#L.Print(L.reduce_mean(ret), message='4')
return ret, ctx_multiheads_attn, ffd_output
......@@ -374,7 +367,7 @@ def encoder(enc_input,
encoder_layer.
"""
#global to_2d, to_3d #, batch, seqlen, dynamic_dim
global to_2d, to_3d #, batch, seqlen, dynamic_dim
d_shape = L.shape(input_mask)
pad_idx = build_pad_idx(input_mask)
attn_bias = build_attn_bias(input_mask, n_head, enc_input.dtype)
......@@ -391,14 +384,14 @@ def encoder(enc_input,
# if attn_bias.dtype != enc_input.dtype:
# attn_bias = L.cast(attn_bias, enc_input.dtype)
# def to_2d(t_3d):
# t_2d = L.gather_nd(t_3d, pad_idx)
# return t_2d
def to_2d(t_3d):
t_2d = L.gather_nd(t_3d, pad_idx)
return t_2d
# def to_3d(t_2d):
# t_3d = L.scatter_nd(
# pad_idx, t_2d, shape=[d_shape[0], d_shape[1], d_model])
# return t_3d
def to_3d(t_2d):
t_3d = L.scatter_nd(
pad_idx, t_2d, shape=[d_shape[0], d_shape[1], d_model])
return t_3d
enc_input = to_2d(enc_input)
all_hidden = []
......@@ -456,7 +449,7 @@ def graph_encoder(enc_input,
encoder_layer.
"""
#global to_2d, to_3d #, batch, seqlen, dynamic_dim
global to_2d, to_3d #, batch, seqlen, dynamic_dim
d_shape = L.shape(input_mask)
pad_idx = build_pad_idx(input_mask)
attn_bias = build_graph_attn_bias(input_mask, n_head, enc_input.dtype, slot_seqlen)
......@@ -474,14 +467,14 @@ def graph_encoder(enc_input,
# if attn_bias.dtype != enc_input.dtype:
# attn_bias = L.cast(attn_bias, enc_input.dtype)
# def to_2d(t_3d):
# t_2d = L.gather_nd(t_3d, pad_idx)
# return t_2d
def to_2d(t_3d):
t_2d = L.gather_nd(t_3d, pad_idx)
return t_2d
# def to_3d(t_2d):
# t_3d = L.scatter_nd(
# pad_idx, t_2d, shape=[d_shape[0], d_shape[1], d_model])
# return t_3d
def to_3d(t_2d):
t_3d = L.scatter_nd(
pad_idx, t_2d, shape=[d_shape[0], d_shape[1], d_model])
return t_3d
enc_input = to_2d(enc_input)
all_hidden = []
......