提交 77360ba7 编写于 作者: J jiangqiu

edit dropout

上级 f5fd8b20
......@@ -38,11 +38,12 @@ int DropoutConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
// if(op_info->HasAttr("is_test")){
// auto is_test = op_info->GetAttr<bool>("is_test");
// CHECK(is_test != true); // The dropout op has no training
// implementation, only inference implementation
// }
// is_test is true by default
// if(op_info->HasAttr("is_test")){
// auto is_test = op_info->GetAttr<bool>("is_test");
// CHECK(is_test != true);
// }
auto dropout_implementation =
op_info->GetAttr<std::string>("dropout_implementation");
auto dropout_prob = op_info->GetAttr<float>("dropout_prob");
......@@ -53,12 +54,12 @@ int DropoutConverter(void* ctx, OpLite* op, KernelBase* kernel) {
float beta = 0.;
std::vector<int64_t> shape = {1, 1, 1, 1};
std::string prefix = string_format("_%p", op);
std::string alpha_var_name = string_format("dropout_alpha_%p", op);
std::string beta_var_name = string_format("dropout_beta_%p", op);
auto alpha_tensor = graph->AddNode(
"Alpha" + prefix, shape, CNML_CONST, CNML_NHWC, graph->FPType());
alpha_var_name, shape, CNML_CONST, CNML_NHWC, graph->FPType());
auto beta_tensor = graph->AddNode(
"Beta" + prefix, shape, CNML_CONST, CNML_NHWC, graph->FPType());
beta_var_name, shape, CNML_CONST, CNML_NHWC, graph->FPType());
graph->BindConstRawData("Alpha" + prefix, &alpha, 1);
graph->BindConstRawData("Beta" + prefix, &beta, 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册