Commit 5d3acea2 authored by: E eclipsess

code style

Parent: cd63d8a2
@@ -17,35 +17,35 @@ limitations under the License. */

#include "operators/kernel/reshape_kernel.h"

namespace paddle_mobile {
namespace operators {

template <>
void ReshapeKernel<CPU, float>::Compute(const ReshapeParam &param) const {
  const auto *input_x = param.InputX();
  const auto &input_x_dims = input_x->dims();
  auto *out = param.Out();
  framework::DDim out_dims = out->dims();
  const auto *input_shape = param.InputShape();

  if (input_shape) {
    auto *shape_data = input_shape->data<int>();
    framework::Tensor cpu_shape_tensor;
    auto shape =
        std::vector<int>(shape_data, shape_data + input_shape->numel());
    out_dims = ValidateShape(shape, input_x->dims());
  }

  bool inplace = param.Inplace();
  out->Resize(out_dims);
  if (!inplace) {
    out->mutable_data<float>();
    framework::TensorCopy(*input_x, out);
    out->Resize(out_dims);
  } else {
    out->ShareDataWith(*input_x);
    out->Resize(out_dims);
  }
}

}  // namespace operators
}  // namespace paddle_mobile
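For readers skimming the kernel above, the sketch below illustrates what the Inplace() flag changes. ToyTensor and its fields are hypothetical stand-ins, not the framework's Tensor; only the copy-versus-alias behaviour mirrors what TensorCopy and ShareDataWith do in the two branches.

#include <cassert>
#include <memory>
#include <vector>

// ToyTensor is a hypothetical stand-in for framework::Tensor, used only to
// illustrate the two branches above: when Inplace() is false the kernel
// allocates fresh storage and copies, when it is true ShareDataWith makes the
// output alias the input's buffer and only the dims metadata changes.
struct ToyTensor {
  std::shared_ptr<std::vector<float>> data;
  std::vector<int> dims;
};

int main() {
  ToyTensor input{std::make_shared<std::vector<float>>(18, 1.0f),
                  {2, 3, 3, 2}};

  // inplace == false: copy the values into a new buffer, then resize.
  ToyTensor copied{std::make_shared<std::vector<float>>(*input.data),
                   {2, 9, 2}};
  assert(copied.data.get() != input.data.get());  // independent storage

  // inplace == true: share the same buffer and only change the metadata.
  ToyTensor shared{input.data, {2, 9, 2}};
  assert(shared.data.get() == input.data.get());  // aliased storage
  return 0;
}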
@@ -20,58 +20,55 @@ limitations under the License. */

#pragma once

namespace paddle_mobile {
namespace operators {

inline framework::DDim ValidateShape(const std::vector<int> shape,
                                     const framework::DDim &in_dims) {
  const int64_t in_size = framework::product(in_dims);
  // only one dimension can be set to -1, whose size will be automatically
  // inferred.
  const int64_t unk_dim_val = -1;
  const int64_t copy_dim_val = 0;

  std::vector<int64_t> output_shape(shape.size(), 0);
  int64_t capacity = 1;
  int unk_dim_idx = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == unk_dim_val) {
      PADDLE_MOBILE_ENFORCE(
          unk_dim_idx == -1,
          "Only one input dimension of Attr(shape) can be unknown.");
      unk_dim_idx = i;
    } else if (shape[i] == copy_dim_val) {
      PADDLE_MOBILE_ENFORCE(
          static_cast<int>(i) < in_dims.size(),
          "The index of dimension to copy from input shape must be less "
          "than the size of input shape.");
    } else {
      PADDLE_MOBILE_ENFORCE(
          shape[i] > 0,
          "Each input dimension of Attr(shape) must not be negative except "
          "one unknown dimension.");
    }

    capacity *= (shape[i] ? shape[i] : in_dims[i]);
    output_shape[i] = (shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
  }

  if (unk_dim_idx != -1) {
    output_shape[unk_dim_idx] = -in_size / capacity;
    PADDLE_MOBILE_ENFORCE(output_shape[unk_dim_idx] * capacity == -in_size,
                          "Invalid shape is given.");
  } else {
    PADDLE_MOBILE_ENFORCE(capacity == in_size, "Invalid shape is given.");
  }
  return framework::make_ddim(output_shape);
}

template <typename DeviceType, typename T>
class ReshapeKernel
    : public framework::OpKernelBase<DeviceType, ReshapeParam> {
 public:
  void Compute(const ReshapeParam &param) const;
};

}  // namespace operators
}  // namespace paddle_mobile
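The -1/0 conventions in ValidateShape are easy to misread, so here is a minimal standalone sketch of the same resolution rules, written against plain STL types rather than the framework's DDim and with the enforcement macros reduced to comments. The attr value {0, -1, 2} is an assumption chosen to reproduce the {2, 9, 2} output used by the test at the end of this commit; it is not taken from the model itself.

#include <cstdint>
#include <iostream>
#include <vector>

// Standalone sketch of the rules ValidateShape implements. Illustrative only;
// this is not the paddle_mobile implementation.
std::vector<int64_t> InferReshape(const std::vector<int> &shape,
                                  const std::vector<int64_t> &in_dims) {
  int64_t in_size = 1;
  for (int64_t d : in_dims) in_size *= d;

  std::vector<int64_t> out(shape.size(), 0);
  int64_t capacity = 1;
  int unk_dim_idx = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == -1) {
      unk_dim_idx = static_cast<int>(i);  // at most one -1 is allowed
    }
    // 0 means "copy the corresponding input dimension".
    capacity *= (shape[i] ? shape[i] : in_dims[i]);
    out[i] = (shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
  }
  if (unk_dim_idx != -1) {
    // capacity is negative here because it absorbed the -1 placeholder, so
    // -in_size / capacity is the positive size that balances the product.
    out[unk_dim_idx] = -in_size / capacity;
  }
  return out;
}

int main() {
  // {2, 3, 3, 2} reshaped with a hypothetical attr shape {0, -1, 2}
  // resolves to {2, 9, 2}: dim 0 is copied, dim 1 is inferred as 36 / 4 = 9.
  for (int64_t d : InferReshape({0, -1, 2}, {2, 3, 3, 2})) {
    std::cout << d << " ";  // prints: 2 9 2
  }
  std::cout << "\n";
  return 0;
}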
@@ -15,18 +15,18 @@ limitations under the License. */

#include "operators/reshape_op.h"
#include <vector>

namespace paddle_mobile {
namespace operators {

template <typename Dtype, typename T>
void ReshapeOp<Dtype, T>::InferShape() const {
  /// todo: add InputShape() detection.
  auto &shape = param_.Shape();
  auto input_x_dims = param_.InputX()->dims();
  auto out_dims = ValidateShape(shape, input_x_dims);
  param_.Out()->Resize(out_dims);
}

template class ReshapeOp<CPU, float>;

}  // namespace operators
}  // namespace paddle_mobile

namespace ops = paddle_mobile::operators;
...
@@ -21,32 +21,31 @@ limitations under the License. */

#include "operators/op_param.h"

namespace paddle_mobile {
namespace operators {

using paddle_mobile::framework::Tensor;

template <typename DeviceType, typename T>
class ReshapeOp : public framework::OperatorWithKernel<DeviceType> {
 public:
  ReshapeOp(const std::string &type, const VariableNameMap &inputs,
            const VariableNameMap &outputs, const framework::AttributeMap attrs,
            std::shared_ptr<framework::Scope> scope)
      : framework::OperatorWithKernel<DeviceType>(type, inputs, outputs, attrs,
                                                  scope),
        param_(inputs, outputs, attrs, *scope) {}

  void Run() const {
    operators::ReshapeKernel<DeviceType, T> kernel;
    kernel.Compute(param_);
  }

  using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;
  void InferShape() const override;

 protected:
  ReshapeParam param_;
};

}  // namespace operators
}  // namespace paddle_mobile
@@ -21,9 +21,9 @@ limitations under the License. */

#include "io.h"
#include "operators/conv_op.h"
#include "operators/pool_op.h"
#include "operators/reshape_op.h"
#include "operators/softmax_op.h"
#include "operators/transpose_op.h"

using paddle_mobile::Executor;
using paddle_mobile::framework::BlockDesc;
...
@@ -17,32 +17,32 @@ limitations under the License. */

#include "./io.h"

int main() {
  paddle_mobile::Loader<paddle_mobile::CPU> loader;
  auto program = loader.Load(std::string("../../test/models/mobilenet+ssd"));
  if (program.originProgram == nullptr) {
    DLOG << "program read file";
  }

  Executor4Test<paddle_mobile::CPU,
                paddle_mobile::operators::ReshapeOp<paddle_mobile::CPU, float>>
      executor(program, "reshape");

  paddle_mobile::framework::Tensor input;
  SetupTensor<float>(&input, {2, 3, 3, 2}, static_cast<float>(0),
                     static_cast<float>(1));
  auto input_ptr = input.data<float>();
  auto out_ddim = paddle_mobile::framework::make_ddim({2, 9, 2});
  auto output =
      executor.predict(input, "transpose_0.tmp_0", "reshape_0.tmp_0", out_ddim);
  auto *output_ptr = output->data<float>();

  DLOG << "input : ";
  for (int j = 0; j < input.numel(); ++j) {
    DLOG << " index " << j << " : " << input_ptr[j];
  }

  DLOG << "output : ";
  for (int j = 0; j < output->numel(); ++j) {
    DLOG << " index " << j << " : " << output_ptr[j];
  }
  return 0;
}
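If this test were to assert rather than just log, a check along the following lines could be appended inside main(). It is a sketch, not part of the commit: it assumes the only op executed between the feed target transpose_0.tmp_0 and the fetch target reshape_0.tmp_0 is the reshape itself, so the fetched tensor holds the same values as the fed one in the same linear order, and it would need <cassert> included.

  // Hypothetical extra check (not in this commit): a reshape only changes the
  // dims metadata, so element j of the fetched output should equal element j
  // of the tensor fed to the executor, assuming no other op runs in between.
  assert(output->numel() == input.numel());
  for (int j = 0; j < output->numel(); ++j) {
    assert(output_ptr[j] == input_ptr[j]);
  }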