diff --git a/lite/kernels/arm/cast_compute.cc b/lite/kernels/arm/cast_compute.cc index 919e9c603edff4383f086ac795c3dff4ed856c4f..f17a83110b8c31ccd30eeed2911430b76680160d 100644 --- a/lite/kernels/arm/cast_compute.cc +++ b/lite/kernels/arm/cast_compute.cc @@ -72,6 +72,12 @@ void CastCompute::Run() { const int64_t* x_data_end = x_data_begin + param.X->numel(); float* out_data = param.Out->mutable_data(); std::transform(x_data_begin, x_data_end, out_data, TransOp); + } else if (param.in_dtype == 2 && param.out_dtype == 3) { // INT32 -> INT64 + const int32_t* x_data_begin = param.X->data(); + const int32_t* x_data_end = x_data_begin + param.X->numel(); + int64_t* out_data = param.Out->mutable_data(); + std::transform( + x_data_begin, x_data_end, out_data, TransOp); } else { LOG(FATAL) << "other has not been implemented transform with dtype" << param.in_dtype << " X, dtype" << param.out_dtype << " Out";