提交 5e0b323d 编写于 作者: J jiangqiu 提交者: MaxwellDing

add dropout:"add dropout to develop-move-mlu-pass-forward branch for run senet101"

上级 fffb2fe6
......@@ -33,10 +33,15 @@ int DropoutConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// Create act node and set params from op
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto mask_var_name = op_info->Output("Mask").front();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
auto mask = scope->FindVar(mask_var_name)->GetMutable<Tensor>();
auto mask_dims = mask->dims().Vectorize();
auto mask_tensor = graph->AddNode(
mask_var_name, mask_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
// is_test is true by default
// if(op_info->HasAttr("is_test")){
......
......@@ -85,9 +85,9 @@ void test_dropout(int bs,
opdesc.SetAttr("seed", seed);
opdesc.SetAttr("dropout_implementation", dropout_implementation);
opdesc.SetAttr("dropout_prob", dropout_prob);
VLOG(6) << "mask: " << mask->dims()[0] << std::endl;
// create and convert op to MLU model, then run it on MLU
auto op = CreateOp<operators::DropoutOp>(opdesc, &scope);
VLOG(6) << "mask: " << mask << std::endl;
dropout_ref(op);
out_ref->CopyDataFrom(*out);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册