提交 6f6a8ae9 编写于 作者: Y yangzhenzhang

embedding lookup auto parallel

上级 28755b2f
......@@ -264,7 +264,8 @@ bool IsSplittableOperator(const std::string &op_name) {
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
EMBEDDING_LOOKUP};
// clang-format on
auto iter = splittable_op.find(op_name);
......
......@@ -115,6 +115,13 @@ def test_auto_parallel_error():
compile_net(net)
def test_auto_parallel():
context.set_context(save_graphs=True)
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2, global_rank=0)
net = Net(split_string="fake")
compile_net(net)
def test_axis_error():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 1), (1, 2))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册