提交 772fa7fc 编写于 作者: E eclipsycn 提交者: GitHub

Merge pull request #277 from Eclipsess/develop

fix #276 add reshape op and test
......@@ -19,6 +19,7 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "common/enforce.h"
#include "common/type_define.h"
#include "common/types.h"
#include "common/variant.h"
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "operators/kernel/reshape_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
void ReshapeKernel<CPU, float>::Compute(const ReshapeParam &param) const {
const auto *input_x = param.InputX();
const auto &input_x_dims = input_x->dims();
auto *out = param.Out();
framework::DDim out_dims = out->dims();
const auto *input_shape = param.InputShape();
if (input_shape) {
auto *shape_data = input_shape->data<int>();
framework::Tensor cpu_shape_tensor;
auto shape =
std::vector<int>(shape_data, shape_data + input_shape->numel());
out_dims = ValidateShape(shape, input_x->dims());
}
bool inplace = param.Inplace();
out->Resize(out_dims);
if (!inplace) {
out->mutable_data<float>();
framework::TensorCopy(*input_x, out);
out->Resize(out_dims);
} else {
out->ShareDataWith(*input_x);
out->Resize(out_dims);
}
}
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "framework/operator.h"
#include "operators/op_param.h"
#pragma once;
namespace paddle_mobile {
namespace operators {
inline framework::DDim ValidateShape(const std::vector<int> shape,
const framework::DDim& in_dims) {
const int64_t in_size = framework::product(in_dims);
// only one dimension can be set to -1, whose size will be automatically
// infered.
const int64_t unk_dim_val = -1;
const int64_t copy_dim_val = 0;
std::vector<int64_t> output_shape(shape.size(), 0);
int64_t capacity = 1;
int unk_dim_idx = -1;
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] == unk_dim_val) {
PADDLE_MOBILE_ENFORCE(
unk_dim_idx == -1,
"Only one input dimension of Attr(shape) can be unknown.");
unk_dim_idx = i;
} else if (shape[i] == copy_dim_val) {
PADDLE_MOBILE_ENFORCE(
static_cast<int>(i) < in_dims.size(),
"The index of dimension to copy from input shape must be less "
"than the size of input shape.");
} else {
PADDLE_MOBILE_ENFORCE(
shape[i] > 0,
"Each input dimension of Attr(shape) must not be negtive except "
"one unknown dimension.");
}
capacity *= (shape[i] ? shape[i] : in_dims[i]);
output_shape[i] = (shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
}
if (unk_dim_idx != -1) {
output_shape[unk_dim_idx] = -in_size / capacity;
PADDLE_MOBILE_ENFORCE(output_shape[unk_dim_idx] * capacity == -in_size,
"Invalid shape is given.");
} else {
PADDLE_MOBILE_ENFORCE(capacity == in_size, "Invalid shape is given.");
}
return framework::make_ddim(output_shape);
}
template <typename DeviceType, typename T>
class ReshapeKernel : public framework::OpKernelBase<DeviceType, ReshapeParam> {
public:
void Compute(const ReshapeParam& param) const;
};
} // namespace operators
} // namespace paddle_mobile
......@@ -99,6 +99,11 @@ class OpParam : PaddleMobileObject {
return GetVarValue<T>("Scores", inputs, scope);
}
template <typename T>
static T *InputShapeFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("Shape", inputs, scope);
}
template <typename T>
static vector<T *> InputMultiFrom(const VariableNameMap &inputs,
const Scope &scope) {
......@@ -636,5 +641,33 @@ class TransposeParam : public OpParam {
vector<int> axis_;
};
class ReshapeParam : public OpParam {
public:
ReshapeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<Tensor>(inputs, scope);
input_shape_ = InputShapeFrom<Tensor>(inputs, scope);
out_ = OutFrom<Tensor>(outputs, scope);
shape_ = GetAttr<vector<int>>("shape", attrs);
inplace_ = GetAttr<bool>("inplace", attrs);
}
const Tensor *InputX() const { return input_x_; }
const Tensor *InputShape() const { return input_shape_; }
Tensor *Out() const { return out_; }
const vector<int> &Shape() const { return shape_; }
const bool &Inplace() const { return inplace_; }
private:
Tensor *input_x_;
Tensor *input_shape_;
Tensor *out_;
vector<int> shape_;
bool inplace_;
};
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/reshape_op.h"
#include <vector>
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void ReshapeOp<Dtype, T>::InferShape() const {
/// todo: add InputShape() detection.
auto &shape = param_.Shape();
auto input_x_dims = param_.InputX()->dims();
auto out_dims = ValidateShape(shape, input_x_dims);
param_.Out()->Resize(out_dims);
}
template class ReshapeOp<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
USE_OP(reshape);
REGISTER_OPERATOR(reshape, ops::ReshapeOp);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/reshape_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class ReshapeOp : public framework::OperatorWithKernel<DeviceType> {
public:
ReshapeOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType>(type, inputs, outputs, attrs,
scope),
param_(inputs, outputs, attrs, *scope) {}
void Run() const {
operators::ReshapeKernel<DeviceType, T> kernel;
kernel.Compute(param_);
}
using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;
void InferShape() const override;
protected:
ReshapeParam param_;
};
} // namespace operators
} // namespace paddle_mobile
......@@ -42,6 +42,10 @@ target_link_libraries(test-transpose-op paddle-mobile)
ADD_EXECUTABLE(test-multiclassnms-op operators/test_multiclass_nms_op.cpp test_helper.h test_include.h)
target_link_libraries(test-multiclassnms-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-reshape-op operators/test_reshape_op.cpp test_helper.h test_include.h)
target_link_libraries(test-reshape-op paddle-mobile)
# gen test log
ADD_EXECUTABLE(test-log common/test_log.cpp)
target_link_libraries(test-log paddle-mobile)
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "io.h"
#include "operators/conv_op.h"
#include "operators/pool_op.h"
#include "operators/reshape_op.h"
#include "operators/softmax_op.h"
#include "operators/transpose_op.h"
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../executor_for_test.h"
#include "../test_helper.h"
#include "./io.h"
int main() {
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string("../../test/models/mobilenet+ssd"));
if (program.originProgram == nullptr) {
DLOG << "program read file";
}
Executor4Test<paddle_mobile::CPU,
paddle_mobile::operators::ReshapeOp<paddle_mobile::CPU, float>>
executor(program, "reshape");
paddle_mobile::framework::Tensor input;
SetupTensor<float>(&input, {2, 3, 3, 2}, static_cast<float>(0),
static_cast<float>(1));
auto input_ptr = input.data<float>();
auto out_ddim = paddle_mobile::framework::make_ddim({2, 9, 2});
auto output =
executor.predict(input, "transpose_0.tmp_0", "reshape_0.tmp_0", out_ddim);
auto *output_ptr = output->data<float>();
DLOG << "input : ";
for (int j = 0; j < input.numel(); ++j) {
DLOG << " index " << j << " : " << input_ptr[j];
}
DLOG << "output : ";
for (int j = 0; j < output->numel(); ++j) {
DLOG << " index " << j << " : " << output_ptr[j];
}
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册