提交 a7646bb6 编写于 作者: Z zp7 提交者: Yanzhan Yang

fix lod_set,cast cpu op (#1786)

* fix lod_set,cast cpu op

* 1.fix lod_set,cast cpu op
2.fix gpu compile error when set_lod called in infershape
上级 2402029d
...@@ -27,7 +27,9 @@ void IncrementOp<Dtype, T>::InferShape() const { ...@@ -27,7 +27,9 @@ void IncrementOp<Dtype, T>::InferShape() const {
auto out = this->param_.Out(); auto out = this->param_.Out();
PADDLE_MOBILE_ENFORCE(input->numel() == 1, "input's numel should be 1"); PADDLE_MOBILE_ENFORCE(input->numel() == 1, "input's numel should be 1");
out->Resize(input->dims()); out->Resize(input->dims());
#ifdef PADDLE_MOBILE_CPU
out->set_lod(input->lod()); out->set_lod(input->lod());
#endif
} }
} // namespace operators } // namespace operators
......
...@@ -40,20 +40,20 @@ struct CastOutOpFunctor { ...@@ -40,20 +40,20 @@ struct CastOutOpFunctor {
} }
}; };
struct CastOpFunctor { // struct CastOpFunctor {
const framework::Tensor* in_; // const framework::Tensor* in_;
framework::Tensor* out_; // framework::Tensor* out_;
int output_type_; // int output_type_;
CastOpFunctor(const framework::Tensor* in, framework::Tensor* out, // CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
const int output_type) // const int output_type)
: in_(in), out_(out), output_type_(output_type) {} // : in_(in), out_(out), output_type_(output_type) {}
//
template <typename InT> // template <typename InT>
void apply() const { // void apply() const {
framework::VisitDataType(framework::ToDataType(output_type_), // framework::VisitDataType(framework::ToDataType(output_type_),
CastOutOpFunctor<InT>(in_, out_)); // CastOutOpFunctor<InT>(in_, out_));
} // }
}; //};
template <> template <>
bool CastKernel<CPU, float>::Init(CastParam<CPU>* param) { bool CastKernel<CPU, float>::Init(CastParam<CPU>* param) {
...@@ -64,8 +64,18 @@ template <> ...@@ -64,8 +64,18 @@ template <>
void CastKernel<CPU, float>::Compute(const CastParam<CPU>& param) { void CastKernel<CPU, float>::Compute(const CastParam<CPU>& param) {
const Tensor* input = param.input_; const Tensor* input = param.input_;
Tensor* output = param.output_; Tensor* output = param.output_;
framework::VisitDataType(framework::ToDataType(param.input_type_), if (input->type() == type_id<float>()) {
CastOpFunctor(input, output, param.output_type_)); framework::VisitDataType(framework::ToDataType(param.output_type_),
CastOutOpFunctor<float>(input, output));
} else if (input->type() == type_id<int64_t>()) {
framework::VisitDataType(framework::ToDataType(param.output_type_),
CastOutOpFunctor<int64_t>(input, output));
} else if (input->type() == type_id<int>()) {
framework::VisitDataType(framework::ToDataType(param.output_type_),
CastOutOpFunctor<int>(input, output));
} else {
PADDLE_MOBILE_ENFORCE(0, "input tpye not support now!")
}
} }
} // namespace operators } // namespace operators
......
...@@ -23,9 +23,11 @@ template <typename Dtype, typename T> ...@@ -23,9 +23,11 @@ template <typename Dtype, typename T>
void LodResetOp<Dtype, T>::InferShape() const { void LodResetOp<Dtype, T>::InferShape() const {
const auto &input_dims = this->param_.input_x_->dims(); const auto &input_dims = this->param_.input_x_->dims();
this->param_.output_->Resize(input_dims); this->param_.output_->Resize(input_dims);
#ifdef PADDLE_MOBILE_CPU
if (this->param_.append) { if (this->param_.append) {
this->param_.output_->set_lod(this->param_.input_x_->lod()); this->param_.output_->set_lod(this->param_.input_x_->lod());
} }
#endif
} }
} // namespace operators } // namespace operators
......
...@@ -26,7 +26,9 @@ void OnehotOp<Dtype, T>::InferShape() const { ...@@ -26,7 +26,9 @@ void OnehotOp<Dtype, T>::InferShape() const {
framework::DDim out_dims(x_dims); framework::DDim out_dims(x_dims);
out_dims[out_dims.size() - 1] = depth; out_dims[out_dims.size() - 1] = depth;
this->param_.output_->Resize(out_dims); this->param_.output_->Resize(out_dims);
#ifdef PADDLE_MOBILE_CPU
this->param_.output_->set_lod(this->param_.input_->lod()); this->param_.output_->set_lod(this->param_.input_->lod());
#endif
} }
} // namespace operators } // namespace operators
......
...@@ -3285,7 +3285,9 @@ class LodResetParam : public OpParam { ...@@ -3285,7 +3285,9 @@ class LodResetParam : public OpParam {
} else { } else {
target_lod_ = OpParam::GetAttr<vector<int>>("target_lod", attrs); target_lod_ = OpParam::GetAttr<vector<int>>("target_lod", attrs);
} }
append = OpParam::GetAttr<bool>("append", attrs); if (HasAttr("append", attrs)) {
append = OpParam::GetAttr<bool>("append", attrs);
}
} }
public: public:
......
...@@ -64,10 +64,12 @@ void ReduceProdOp<Dtype, T>::InferShape() const { ...@@ -64,10 +64,12 @@ void ReduceProdOp<Dtype, T>::InferShape() const {
} }
auto out_dims = framework::make_ddim(dims_vector); auto out_dims = framework::make_ddim(dims_vector);
this->param_.Output()->Resize(out_dims); this->param_.Output()->Resize(out_dims);
#ifdef PADDLE_MOBILE_CPU
if (dims[0] != 0) { if (dims[0] != 0) {
// Only pass LoD when not reducing on the first dim. // Only pass LoD when not reducing on the first dim.
this->param_.Output()->set_lod(this->param_.Input()->lod()); this->param_.Output()->set_lod(this->param_.Input()->lod());
} }
#endif
} }
} }
......
...@@ -84,9 +84,11 @@ void SliceOp<Dtype, T>::InferShape() const { ...@@ -84,9 +84,11 @@ void SliceOp<Dtype, T>::InferShape() const {
} }
} }
output->Resize(out_dims); output->Resize(out_dims);
#ifdef PADDLE_MOBILE_CPU
if (axes[0] != 0) { if (axes[0] != 0) {
output->set_lod(input->lod()); output->set_lod(input->lod());
} }
#endif
} }
} // namespace operators } // namespace operators
......
...@@ -27,8 +27,10 @@ void TopKOp<DeviceType, T>::InferShape() const { ...@@ -27,8 +27,10 @@ void TopKOp<DeviceType, T>::InferShape() const {
dims[dims.size() - 1] = k; dims[dims.size() - 1] = k;
this->param_.output_->Resize(dims); this->param_.output_->Resize(dims);
this->param_.indices_->Resize(dims); this->param_.indices_->Resize(dims);
#ifdef PADDLE_MOBILE_CPU
this->param_.output_->set_lod(this->param_.input_->lod()); this->param_.output_->set_lod(this->param_.input_->lod());
this->param_.indices_->set_lod(this->param_.input_->lod()); this->param_.indices_->set_lod(this->param_.input_->lod());
#endif
} }
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册