提交 a30ba1a6 编写于 作者: W Wilber 提交者: GitHub

fix fill_constant bug and add int64->int32 cast test=develop (#2566)

- fix fill_constant bug.
- cast op support int64_t->int32_t
上级 da043b14
......@@ -56,6 +56,12 @@ void CastCompute::Run() {
float* out_data = param.Out->mutable_data<float>();
std::transform(
x_data_begin, x_data_end, out_data, TransOp<unsigned char, float>);
} else if (param.in_dtype == 3 && param.out_dtype == 2) {
const int64_t* x_data_begin = param.X->data<int64_t>();
const int64_t* x_data_end = x_data_begin + param.X->numel();
int32_t* out_data = param.Out->mutable_data<int32_t>();
std::transform(
x_data_begin, x_data_end, out_data, TransOp<int64_t, int32_t>);
} else {
LOG(FATAL) << "other has not been implemented";
}
......
......@@ -51,9 +51,8 @@ class FillConstantOp : public OpLite {
param_.shape_tensor_list = {};
std::vector<std::string> input_arg_names = opdesc.InputArgumentNames();
if (std::find(input_arg_names.begin(),
input_arg_names.end(),
"ShapeTensor") != input_arg_names.end()) {
if (opdesc.HasInput("ShapeTensor") &&
!opdesc.Input("ShapeTensor").empty()) {
auto args = opdesc.Input("ShapeTensor");
auto* var = scope->FindVar(args.front());
param_.shape_tensor = var->GetMutable<lite::Tensor>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册