提交 5d3acea2 编写于 作者: E eclipsess

code style

上级 cd63d8a2
...@@ -17,10 +17,10 @@ limitations under the License. */ ...@@ -17,10 +17,10 @@ limitations under the License. */
#include "operators/kernel/reshape_kernel.h" #include "operators/kernel/reshape_kernel.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
void ReshapeKernel<CPU, float>::Compute(const ReshapeParam &param) const { void ReshapeKernel<CPU, float>::Compute(const ReshapeParam &param) const {
const auto *input_x = param.InputX(); const auto *input_x = param.InputX();
const auto &input_x_dims = input_x->dims(); const auto &input_x_dims = input_x->dims();
auto *out = param.Out(); auto *out = param.Out();
...@@ -39,13 +39,13 @@ namespace paddle_mobile { ...@@ -39,13 +39,13 @@ namespace paddle_mobile {
out->Resize(out_dims); out->Resize(out_dims);
if (!inplace) { if (!inplace) {
out->mutable_data<float>(); out->mutable_data<float>();
framework::TensorCopy(*input_x,out); framework::TensorCopy(*input_x, out);
out->Resize(out_dims); out->Resize(out_dims);
} else { } else {
out->ShareDataWith(*input_x); out->ShareDataWith(*input_x);
out->Resize(out_dims); out->Resize(out_dims);
} }
} }
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -20,11 +20,10 @@ limitations under the License. */ ...@@ -20,11 +20,10 @@ limitations under the License. */
#pragma once; #pragma once;
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
inline framework::DDim ValidateShape(const std::vector<int> shape,
inline framework::DDim ValidateShape(const std::vector<int> shape, const framework::DDim& in_dims) {
const framework::DDim &in_dims) {
const int64_t in_size = framework::product(in_dims); const int64_t in_size = framework::product(in_dims);
// only one dimension can be set to -1, whose size will be automatically // only one dimension can be set to -1, whose size will be automatically
// infered. // infered.
...@@ -53,8 +52,7 @@ namespace paddle_mobile { ...@@ -53,8 +52,7 @@ namespace paddle_mobile {
} }
capacity *= (shape[i] ? shape[i] : in_dims[i]); capacity *= (shape[i] ? shape[i] : in_dims[i]);
output_shape[i] = output_shape[i] = (shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
(shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
} }
if (unk_dim_idx != -1) { if (unk_dim_idx != -1) {
...@@ -62,16 +60,15 @@ namespace paddle_mobile { ...@@ -62,16 +60,15 @@ namespace paddle_mobile {
PADDLE_MOBILE_ENFORCE(output_shape[unk_dim_idx] * capacity == -in_size, PADDLE_MOBILE_ENFORCE(output_shape[unk_dim_idx] * capacity == -in_size,
"Invalid shape is given."); "Invalid shape is given.");
} else { } else {
PADDLE_MOBILE_ENFORCE(capacity==in_size, "Invalid shape is given."); PADDLE_MOBILE_ENFORCE(capacity == in_size, "Invalid shape is given.");
} }
return framework::make_ddim(output_shape); return framework::make_ddim(output_shape);
} }
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ReshapeKernel class ReshapeKernel : public framework::OpKernelBase<DeviceType, ReshapeParam> {
: public framework::OpKernelBase<DeviceType, ReshapeParam> {
public: public:
void Compute(const ReshapeParam& param) const; void Compute(const ReshapeParam& param) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -15,18 +15,18 @@ limitations under the License. */ ...@@ -15,18 +15,18 @@ limitations under the License. */
#include "operators/reshape_op.h" #include "operators/reshape_op.h"
#include <vector> #include <vector>
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
template <typename Dtype, typename T> template <typename Dtype, typename T>
void ReshapeOp<Dtype, T>::InferShape() const { void ReshapeOp<Dtype, T>::InferShape() const {
/// todo: add InputShape() detection. /// todo: add InputShape() detection.
auto &shape = param_.Shape(); auto &shape = param_.Shape();
auto input_x_dims = param_.InputX()->dims(); auto input_x_dims = param_.InputX()->dims();
auto out_dims = ValidateShape(shape, input_x_dims); auto out_dims = ValidateShape(shape, input_x_dims);
param_.Out()->Resize(out_dims); param_.Out()->Resize(out_dims);
} }
template class ReshapeOp<CPU, float>; template class ReshapeOp<CPU, float>;
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
namespace ops = paddle_mobile::operators; namespace ops = paddle_mobile::operators;
......
...@@ -21,16 +21,15 @@ limitations under the License. */ ...@@ -21,16 +21,15 @@ limitations under the License. */
#include "operators/op_param.h" #include "operators/op_param.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
using paddle_mobile::framework::Tensor; using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ReshapeOp : public framework::OperatorWithKernel<DeviceType> { class ReshapeOp : public framework::OperatorWithKernel<DeviceType> {
public: public:
ReshapeOp(const std::string &type, const VariableNameMap &inputs, ReshapeOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs, const framework::AttributeMap attrs,
const framework::AttributeMap attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType>(type, inputs, outputs, attrs, : framework::OperatorWithKernel<DeviceType>(type, inputs, outputs, attrs,
scope), scope),
...@@ -46,7 +45,7 @@ namespace paddle_mobile { ...@@ -46,7 +45,7 @@ namespace paddle_mobile {
protected: protected:
ReshapeParam param_; ReshapeParam param_;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -21,9 +21,9 @@ limitations under the License. */ ...@@ -21,9 +21,9 @@ limitations under the License. */
#include "io.h" #include "io.h"
#include "operators/conv_op.h" #include "operators/conv_op.h"
#include "operators/pool_op.h" #include "operators/pool_op.h"
#include "operators/reshape_op.h"
#include "operators/softmax_op.h" #include "operators/softmax_op.h"
#include "operators/transpose_op.h" #include "operators/transpose_op.h"
#include "operators/reshape_op.h"
using paddle_mobile::Executor; using paddle_mobile::Executor;
using paddle_mobile::framework::BlockDesc; using paddle_mobile::framework::BlockDesc;
......
...@@ -22,8 +22,8 @@ int main() { ...@@ -22,8 +22,8 @@ int main() {
if (program.originProgram == nullptr) { if (program.originProgram == nullptr) {
DLOG << "program read file"; DLOG << "program read file";
} }
Executor4Test<paddle_mobile::CPU, paddle_mobile::operators::ReshapeOp< Executor4Test<paddle_mobile::CPU,
paddle_mobile::CPU, float>> paddle_mobile::operators::ReshapeOp<paddle_mobile::CPU, float>>
executor(program, "reshape"); executor(program, "reshape");
paddle_mobile::framework::Tensor input; paddle_mobile::framework::Tensor input;
SetupTensor<float>(&input, {2, 3, 3, 2}, static_cast<float>(0), SetupTensor<float>(&input, {2, 3, 3, 2}, static_cast<float>(0),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册