提交 1e3d867e 编写于 作者: - --get 提交者: jackzhang235

(bugfix): fix a small mistake in dropout

上级 04199b2a
...@@ -61,8 +61,8 @@ int DropoutConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -61,8 +61,8 @@ int DropoutConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto beta_tensor = graph->AddNode( auto beta_tensor = graph->AddNode(
beta_var_name, shape, CNML_CONST, CNML_NHWC, graph->FPType()); beta_var_name, shape, CNML_CONST, CNML_NHWC, graph->FPType());
graph->BindConstRawData("Alpha" + prefix, &alpha, 1); graph->BindConstRawData(alpha_var_name, &alpha, 1);
graph->BindConstRawData("Beta" + prefix, &beta, 1); graph->BindConstRawData(beta_var_name, &beta, 1);
auto input_tensor = graph->GetNode(x_var_name); auto input_tensor = graph->GetNode(x_var_name);
cnmlBaseOp_t scale_op; cnmlBaseOp_t scale_op;
......
...@@ -85,9 +85,9 @@ void test_dropout(int bs, ...@@ -85,9 +85,9 @@ void test_dropout(int bs,
opdesc.SetAttr("seed", seed); opdesc.SetAttr("seed", seed);
opdesc.SetAttr("dropout_implementation", dropout_implementation); opdesc.SetAttr("dropout_implementation", dropout_implementation);
opdesc.SetAttr("dropout_prob", dropout_prob); opdesc.SetAttr("dropout_prob", dropout_prob);
VLOG(6) << "mask: " << mask->dims()[0] << std::endl;
// create and convert op to MLU model, then run it on MLU // create and convert op to MLU model, then run it on MLU
auto op = CreateOp<operators::DropoutOp>(opdesc, &scope); auto op = CreateOp<operators::DropoutOp>(opdesc, &scope);
VLOG(6) << "mask: " << mask << std::endl;
dropout_ref(op); dropout_ref(op);
out_ref->CopyDataFrom(*out); out_ref->CopyDataFrom(*out);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册