Commit a7646bb6 authored by zp7, committed by Yanzhan Yang

fix lod_set,cast cpu op (#1786)

* fix lod_set,cast cpu op

* 1.fix lod_set,cast cpu op
2.fix gpu compile error when set_lod called in infershape
Parent 2402029d
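
The recurring fix in this commit is to wrap `set_lod` calls inside `InferShape` with a `PADDLE_MOBILE_CPU` guard, because LoD propagation there only applies to the CPU kernels and the unguarded call broke the GPU build. A minimal sketch of the pattern, using an illustrative operator name rather than one taken from the diff:

```cpp
// Illustrative only: the guard pattern applied to the ops touched below
// (increment, lod_reset, one_hot, reduce_prod, slice, top_k).
template <typename Dtype, typename T>
void ExampleOp<Dtype, T>::InferShape() const {
  auto *input = this->param_.input_;
  auto *output = this->param_.output_;
  output->Resize(input->dims());
#ifdef PADDLE_MOBILE_CPU
  // LoD propagation is CPU-only; guarding it here avoids the GPU
  // compile error mentioned in the commit message.
  output->set_lod(input->lod());
#endif
}
```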
......@@ -27,7 +27,9 @@ void IncrementOp<Dtype, T>::InferShape() const {
auto out = this->param_.Out();
PADDLE_MOBILE_ENFORCE(input->numel() == 1, "input's numel should be 1");
out->Resize(input->dims());
#ifdef PADDLE_MOBILE_CPU
out->set_lod(input->lod());
#endif
}
} // namespace operators
......
......@@ -40,20 +40,20 @@ struct CastOutOpFunctor {
}
};
struct CastOpFunctor {
const framework::Tensor* in_;
framework::Tensor* out_;
int output_type_;
CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
const int output_type)
: in_(in), out_(out), output_type_(output_type) {}
template <typename InT>
void apply() const {
framework::VisitDataType(framework::ToDataType(output_type_),
CastOutOpFunctor<InT>(in_, out_));
}
};
// struct CastOpFunctor {
// const framework::Tensor* in_;
// framework::Tensor* out_;
// int output_type_;
// CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
// const int output_type)
// : in_(in), out_(out), output_type_(output_type) {}
//
// template <typename InT>
// void apply() const {
// framework::VisitDataType(framework::ToDataType(output_type_),
// CastOutOpFunctor<InT>(in_, out_));
// }
//};
template <>
bool CastKernel<CPU, float>::Init(CastParam<CPU>* param) {
......@@ -64,8 +64,18 @@ template <>
void CastKernel<CPU, float>::Compute(const CastParam<CPU>& param) {
const Tensor* input = param.input_;
Tensor* output = param.output_;
framework::VisitDataType(framework::ToDataType(param.input_type_),
CastOpFunctor(input, output, param.output_type_));
if (input->type() == type_id<float>()) {
framework::VisitDataType(framework::ToDataType(param.output_type_),
CastOutOpFunctor<float>(input, output));
} else if (input->type() == type_id<int64_t>()) {
framework::VisitDataType(framework::ToDataType(param.output_type_),
CastOutOpFunctor<int64_t>(input, output));
} else if (input->type() == type_id<int>()) {
framework::VisitDataType(framework::ToDataType(param.output_type_),
CastOutOpFunctor<int>(input, output));
} else {
PADDLE_MOBILE_ENFORCE(0, "input tpye not support now!")
}
}
} // namespace operators
......
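
In the cast kernel, the removed `CastOpFunctor` performed a second `VisitDataType` over the input type; the new code dispatches on `input->type()` directly and only visits the output type through `CastOutOpFunctor<InT>`. The body of `CastOutOpFunctor` is not shown in this hunk; below is a hedged sketch of what such an output-type functor typically looks like (member and method names are assumptions, not taken from this commit):

```cpp
// Hedged sketch of an output-type cast functor; the real CastOutOpFunctor
// in paddle-mobile may differ in detail.
template <typename InT>
struct CastOutFunctorSketch {
  const framework::Tensor *in_;
  framework::Tensor *out_;
  CastOutFunctorSketch(const framework::Tensor *in, framework::Tensor *out)
      : in_(in), out_(out) {}

  // VisitDataType instantiates apply<OutT>() for the runtime output type,
  // so the element-wise cast is resolved to a concrete (InT -> OutT) pair.
  template <typename OutT>
  void apply() const {
    const InT *in_data = in_->data<InT>();
    OutT *out_data = out_->mutable_data<OutT>();
    for (int i = 0; i < in_->numel(); ++i) {
      out_data[i] = static_cast<OutT>(in_data[i]);
    }
  }
};
```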
......@@ -23,9 +23,11 @@ template <typename Dtype, typename T>
void LodResetOp<Dtype, T>::InferShape() const {
const auto &input_dims = this->param_.input_x_->dims();
this->param_.output_->Resize(input_dims);
#ifdef PADDLE_MOBILE_CPU
if (this->param_.append) {
this->param_.output_->set_lod(this->param_.input_x_->lod());
}
#endif
}
} // namespace operators
......
......@@ -26,7 +26,9 @@ void OnehotOp<Dtype, T>::InferShape() const {
framework::DDim out_dims(x_dims);
out_dims[out_dims.size() - 1] = depth;
this->param_.output_->Resize(out_dims);
#ifdef PADDLE_MOBILE_CPU
this->param_.output_->set_lod(this->param_.input_->lod());
#endif
}
} // namespace operators
......
......@@ -3285,7 +3285,9 @@ class LodResetParam : public OpParam {
} else {
target_lod_ = OpParam::GetAttr<vector<int>>("target_lod", attrs);
}
append = OpParam::GetAttr<bool>("append", attrs);
if (HasAttr("append", attrs)) {
append = OpParam::GetAttr<bool>("append", attrs);
}
}
public:
......
......@@ -64,10 +64,12 @@ void ReduceProdOp<Dtype, T>::InferShape() const {
}
auto out_dims = framework::make_ddim(dims_vector);
this->param_.Output()->Resize(out_dims);
#ifdef PADDLE_MOBILE_CPU
if (dims[0] != 0) {
// Only pass LoD when not reducing on the first dim.
this->param_.Output()->set_lod(this->param_.Input()->lod());
}
#endif
}
}
......
......@@ -84,9 +84,11 @@ void SliceOp<Dtype, T>::InferShape() const {
}
}
output->Resize(out_dims);
#ifdef PADDLE_MOBILE_CPU
if (axes[0] != 0) {
output->set_lod(input->lod());
}
#endif
}
} // namespace operators
......
......@@ -27,8 +27,10 @@ void TopKOp<DeviceType, T>::InferShape() const {
dims[dims.size() - 1] = k;
this->param_.output_->Resize(dims);
this->param_.indices_->Resize(dims);
#ifdef PADDLE_MOBILE_CPU
this->param_.output_->set_lod(this->param_.input_->lod());
this->param_.indices_->set_lod(this->param_.input_->lod());
#endif
}
} // namespace operators
......