diff --git a/lite/kernels/arm/cast_compute.cc b/lite/kernels/arm/cast_compute.cc index bc274ea22485e84a1cc9145e62fc967f2847c5dd..1fef52bcb77b7c3efdcd848ee63f8ec46c16d6f8 100644 --- a/lite/kernels/arm/cast_compute.cc +++ b/lite/kernels/arm/cast_compute.cc @@ -56,6 +56,12 @@ void CastCompute::Run() { float* out_data = param.Out->mutable_data(); std::transform( x_data_begin, x_data_end, out_data, TransOp); + } else if (param.in_dtype == 3 && param.out_dtype == 2) { + const int64_t* x_data_begin = param.X->data(); + const int64_t* x_data_end = x_data_begin + param.X->numel(); + int32_t* out_data = param.Out->mutable_data(); + std::transform( + x_data_begin, x_data_end, out_data, TransOp); } else { LOG(FATAL) << "other has not been implemented"; } diff --git a/lite/operators/fill_constant_op.cc b/lite/operators/fill_constant_op.cc index acf9701cbd750e83ba51f25c66064c2dd7781db6..bd4b483e9ed20f89bf2d072ca21bdc24a0e82256 100644 --- a/lite/operators/fill_constant_op.cc +++ b/lite/operators/fill_constant_op.cc @@ -51,9 +51,8 @@ class FillConstantOp : public OpLite { param_.shape_tensor_list = {}; std::vector input_arg_names = opdesc.InputArgumentNames(); - if (std::find(input_arg_names.begin(), - input_arg_names.end(), - "ShapeTensor") != input_arg_names.end()) { + if (opdesc.HasInput("ShapeTensor") && + !opdesc.Input("ShapeTensor").empty()) { auto args = opdesc.Input("ShapeTensor"); auto* var = scope->FindVar(args.front()); param_.shape_tensor = var->GetMutable();