From a30ba1a6fa40d80af0df139ae8bf745889aba34f Mon Sep 17 00:00:00 2001 From: Wilber Date: Fri, 6 Dec 2019 10:54:07 +0800 Subject: [PATCH] fix fill_constant bug and add int64->int32 cast test=develop (#2566) - fix fill_constant bug. - cast op support int64_t->int32_t --- lite/kernels/arm/cast_compute.cc | 6 ++++++ lite/operators/fill_constant_op.cc | 5 ++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/lite/kernels/arm/cast_compute.cc b/lite/kernels/arm/cast_compute.cc index bc274ea224..1fef52bcb7 100644 --- a/lite/kernels/arm/cast_compute.cc +++ b/lite/kernels/arm/cast_compute.cc @@ -56,6 +56,12 @@ void CastCompute::Run() { float* out_data = param.Out->mutable_data(); std::transform( x_data_begin, x_data_end, out_data, TransOp); + } else if (param.in_dtype == 3 && param.out_dtype == 2) { + const int64_t* x_data_begin = param.X->data(); + const int64_t* x_data_end = x_data_begin + param.X->numel(); + int32_t* out_data = param.Out->mutable_data(); + std::transform( + x_data_begin, x_data_end, out_data, TransOp); } else { LOG(FATAL) << "other has not been implemented"; } diff --git a/lite/operators/fill_constant_op.cc b/lite/operators/fill_constant_op.cc index acf9701cbd..bd4b483e9e 100644 --- a/lite/operators/fill_constant_op.cc +++ b/lite/operators/fill_constant_op.cc @@ -51,9 +51,8 @@ class FillConstantOp : public OpLite { param_.shape_tensor_list = {}; std::vector input_arg_names = opdesc.InputArgumentNames(); - if (std::find(input_arg_names.begin(), - input_arg_names.end(), - "ShapeTensor") != input_arg_names.end()) { + if (opdesc.HasInput("ShapeTensor") && + !opdesc.Input("ShapeTensor").empty()) { auto args = opdesc.Input("ShapeTensor"); auto* var = scope->FindVar(args.front()); param_.shape_tensor = var->GetMutable(); -- GitLab