提交 2cd8c947 编写于 作者: Y Yao,kun

bugfix

上级 0186dee5
...@@ -19,8 +19,8 @@ namespace operators { ...@@ -19,8 +19,8 @@ namespace operators {
template <typename Dtype, typename T> template <typename Dtype, typename T>
void DropoutOp<Dtype, T>::InferShape() const { void DropoutOp<Dtype, T>::InferShape() const {
auto input_dims = param_.InputX()->dims(); auto input_dims = thsi->param_.InputX()->dims();
param_.Out()->Resize(input_dims); this->param_.Out()->Resize(input_dims);
} }
template class DropoutOp<CPU, float>; template class DropoutOp<CPU, float>;
} // namespace operators } // namespace operators
...@@ -28,7 +28,7 @@ template class DropoutOp<CPU, float>; ...@@ -28,7 +28,7 @@ template class DropoutOp<CPU, float>;
namespace ops = paddle_mobile::operators; namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU #ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(Dropout); USE_OP_CPU(dropout);
REGISTER_OPERATOR_CPU(dropout, ops::DropoutOp); REGISTER_OPERATOR_CPU(dropout, ops::DropoutOp);
#endif #endif
#ifdef PADDLE_MOBILE_MALI_GPU #ifdef PADDLE_MOBILE_MALI_GPU
......
...@@ -37,11 +37,10 @@ class DropoutOp ...@@ -37,11 +37,10 @@ class DropoutOp
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, DropoutParam, : framework::OperatorWithKernel<DeviceType, DropoutParam,
operators::DropoutKernel<DeviceType, T>>( operators::DropoutKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope), type, inputs, outputs, attrs, scope) {}
param_(inputs, outputs, attrs, *scope) {}
using framework::OperatorWithKernel<DeviceType, DropoutParam, //using framework::OperatorWithKernel<DeviceType, DropoutParam,
operators::DropoutKernel<DeviceType, T>>; // operators::DropoutKernel<DeviceType, T>>;
void InferShape() const override; void InferShape() const override;
protected: protected:
......
...@@ -28,13 +28,13 @@ int Im2SequenceOutputSize(int input_size, int kernel, int padding_1, ...@@ -28,13 +28,13 @@ int Im2SequenceOutputSize(int input_size, int kernel, int padding_1,
template <typename Dtype, typename T> template <typename Dtype, typename T>
void Im2SequenceOp<Dtype, T>::InferShape() const { void Im2SequenceOp<Dtype, T>::InferShape() const {
auto in_x_dims = param_.Input()->dims(); auto in_x_dims = this->param_.Input()->dims();
const std::vector<int> &kernels = param_.Kernels(); const std::vector<int> &kernels = this->param_.Kernels();
const std::vector<int> &strides = param_.Strides(); const std::vector<int> &strides = this->param_.Strides();
std::vector<int> paddings = param_.Paddings(); std::vector<int> paddings = this->param_.Paddings();
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]}); std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < strides.size(); ++i) { for (size_t i = 0; i < strides.size(); ++i) {
...@@ -44,7 +44,7 @@ void Im2SequenceOp<Dtype, T>::InferShape() const { ...@@ -44,7 +44,7 @@ void Im2SequenceOp<Dtype, T>::InferShape() const {
} }
framework::DDim ddim = framework::make_ddim(output_shape); framework::DDim ddim = framework::make_ddim(output_shape);
param_.Output()->Resize(ddim); this->param_.Output()->Resize(ddim);
} }
template class Im2SequenceOp<CPU, float>; template class Im2SequenceOp<CPU, float>;
......
...@@ -37,12 +37,11 @@ class Im2SequenceOp : public framework::OperatorWithKernel< ...@@ -37,12 +37,11 @@ class Im2SequenceOp : public framework::OperatorWithKernel<
: framework::OperatorWithKernel< : framework::OperatorWithKernel<
DeviceType, Im2SequenceParam, DeviceType, Im2SequenceParam,
operators::Im2SequenceKernel<DeviceType, T>>(type, inputs, outputs, operators::Im2SequenceKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope), attrs, scope) {}
param_(inputs, outputs, attrs, *scope) {}
using framework::OperatorWithKernel< //using framework::OperatorWithKernel<
DeviceType, Im2SequenceParam, // DeviceType, Im2SequenceParam,
operators::Im2SequenceKernel<DeviceType, T>>::OperatorWithKernel; // operators::Im2SequenceKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override; void InferShape() const override;
private: private:
......
...@@ -23,7 +23,7 @@ namespace paddle_mobile { ...@@ -23,7 +23,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool DropoutKernel<CPU, float>::Init(const DropoutParam &para) const { bool DropoutKernel<CPU, float>::Init(DropoutParam *para) {
return true; return true;
} }
......
...@@ -20,7 +20,7 @@ namespace paddle_mobile { ...@@ -20,7 +20,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool Im2SequenceKernel<CPU, float>::Init(const Im2SequenceParam &para) const { bool Im2SequenceKernel<CPU, float>::Init(Im2SequenceParam *para) {
return true; return true;
} }
......
...@@ -26,7 +26,7 @@ template <typename DeviceType, typename T> ...@@ -26,7 +26,7 @@ template <typename DeviceType, typename T>
class DropoutKernel : public framework::OpKernelBase<DeviceType, DropoutParam> { class DropoutKernel : public framework::OpKernelBase<DeviceType, DropoutParam> {
public: public:
void Compute(const DropoutParam& param) const; void Compute(const DropoutParam& param) const;
bool Init(const DropoutParam& para) const; bool Init(DropoutParam *para);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -32,7 +32,7 @@ class Im2SequenceKernel ...@@ -32,7 +32,7 @@ class Im2SequenceKernel
: public framework::OpKernelBase<DeviceType, Im2SequenceParam> { : public framework::OpKernelBase<DeviceType, Im2SequenceParam> {
public: public:
void Compute(const Im2SequenceParam &param) const; void Compute(const Im2SequenceParam &param) const;
bool Init(const Im2SequenceParam &para) const; bool Init(Im2SequenceParam* para);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -156,6 +156,6 @@ endif() ...@@ -156,6 +156,6 @@ endif()
if (DROPOUT_OP) if (DROPOUT_OP)
add_definitions(-DDROPOUT_OP) add_definitions(-DDROPOUT_OP)
endif() endif()
if (IM2SQUENCE_OP) if (IM2SEQUENCE_OP)
add_definitions(-DIM2SQUENCE_OP) add_definitions(-DIM2SEQUENCE_OP)
endif() endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册