diff --git a/mace/ops/lstmcell.h b/mace/ops/lstmcell.h index a403237947bf286357a08f5c5918e7a95643b483..c632cd3dee39c20b9e1cc4420f679c01bb1c0e63 100644 --- a/mace/ops/lstmcell.h +++ b/mace/ops/lstmcell.h @@ -29,7 +29,7 @@ class LSTMCellOp : public Operator { LSTMCellOp(const OperatorDef &op_def, Workspace *ws) : Operator(op_def, ws), functor_(static_cast( - OperatorBase::GetOptionalArg("value", 0.0))) {} + OperatorBase::GetOptionalArg("scalar_input", 0.0))) {} MaceStatus Run(StatsFuture *future) override { const Tensor *input = this->Input(INPUT); diff --git a/mace/ops/lstmcell_test.cc b/mace/ops/lstmcell_test.cc index 418661937d1ef684d0bade0cdc3fb73b23cd17c2..109096c5ffec45f6d022108017c55b5c4f609799 100644 --- a/mace/ops/lstmcell_test.cc +++ b/mace/ops/lstmcell_test.cc @@ -84,7 +84,7 @@ void LSTMCellCPU(OpsTestNet *net, OpDefBuilder("Eltwise", "ForgetAdd") .Input("SplitOutput2") - .AddFloatArg("value", forget_add_name) + .AddFloatArg("scalar_input", forget_add_name) .AddIntArg("T", DataTypeToEnum::v()) .AddIntArg("type", static_cast(kernels::EltwiseType::SUM)) .Output("ForgetAdd") @@ -176,7 +176,7 @@ void TestLSTMCell(const uint32_t &batch, .Input("WeightImage") .Input("BiasImage") .Input("PreCellImage") - .AddFloatArg("forget_add", forget_add) + .AddFloatArg("scalar_input", forget_add) .Output("CellImage") .Output("OutputImage") .Finalize(net.NewOperatorDef());