提交 68d24507 编写于 作者: H hjchen2

add pad2d op, and fix bugs

上级 89fc6315
......@@ -112,6 +112,8 @@ const char *G_OP_TYPE_GENERATE_PROPOSALS = "generate_proposals";
const char *G_OP_TYPE_PSROI_POOL = "psroi_pool";
const char *G_OP_TYPE_ROI_PERSPECTIVE = "roi_perspective_transform";
const char *G_OP_TYPE_PAD2D = "pad2d";
std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
op_input_output_key = {
......@@ -210,5 +212,6 @@ std::unordered_map<
{{"Scores", "BboxDeltas", "ImInfo", "Anchors", "Variances"},
{"RpnRois", "RpnRoiProbs"}}},
{G_OP_TYPE_PSROI_POOL, {{"X", "ROIs"}, {"Out"}}},
{G_OP_TYPE_ROI_PERSPECTIVE, {{"X", "ROIs"}, {"Out"}}}};
{G_OP_TYPE_ROI_PERSPECTIVE, {{"X", "ROIs"}, {"Out"}}},
{G_OP_TYPE_PAD2D, {{"X"}, {"Out"}}}};
} // namespace paddle_mobile
......@@ -200,6 +200,8 @@ extern const char *G_OP_TYPE_GENERATE_PROPOSALS;
extern const char *G_OP_TYPE_PSROI_POOL;
extern const char *G_OP_TYPE_ROI_PERSPECTIVE;
extern const char *G_OP_TYPE_PAD2D;
extern std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
op_input_output_key;
......
......@@ -378,18 +378,14 @@ std::vector<T> Executor<Device, T>::Predict(const std::vector<T> &input,
template <typename Device, typename T>
void Executor<Device, T>::SetInput(const Tensor &input,
const std::string &var_name) {
framework::LoDTensor *target = nullptr;
int index = 0;
if (feed_indices_.find(var_name) != feed_indices_.end()) {
int index = feed_indices_.find(var_name)->second;
auto *feed_var = program_.scope->Var("feed");
target = &(
feed_var->template GetMutable<framework::LoDTensorArray>()->at(index));
} else {
auto *target_var = program_.scope->FindVar(var_name);
PADDLE_MOBILE_ENFORCE(target_var != nullptr, "Variable %s is not exist",
var_name.c_str());
target = target_var->template GetMutable<LoDTensor>();
index = feed_indices_.find(var_name)->second;
}
auto *feed_var = program_.scope->Var("feed");
framework::LoDTensor &target =
feed_var->template GetMutable<framework::LoDTensorArray>()->at(index);
if (config_.load_when_predict) {
if (input_dim_last_ != input.dims()) {
InitNoPersistableMemory(input);
......@@ -397,25 +393,21 @@ void Executor<Device, T>::SetInput(const Tensor &input,
}
}
target->Resize(input.dims());
target->ShareDataWith(input);
target.Resize(input.dims());
target.ShareDataWith(input);
}
template <typename Device, typename T>
void Executor<Device, T>::SetInput(const LoDTensor &input,
const std::string &var_name) {
framework::LoDTensor *target = nullptr;
int index = 0;
if (feed_indices_.find(var_name) != feed_indices_.end()) {
int index = feed_indices_.find(var_name)->second;
auto *feed_var = program_.scope->Var("feed");
target = &(
feed_var->template GetMutable<framework::LoDTensorArray>()->at(index));
} else {
auto *target_var = program_.scope->FindVar(var_name);
PADDLE_MOBILE_ENFORCE(target_var != nullptr, "Variable %s is not exist",
var_name.c_str());
target = target_var->template GetMutable<LoDTensor>();
index = feed_indices_.find(var_name)->second;
}
auto *feed_var = program_.scope->Var("feed");
framework::LoDTensor &target =
feed_var->template GetMutable<framework::LoDTensorArray>()->at(index);
if (config_.load_when_predict) {
if (input_dim_last_ != input.dims()) {
InitNoPersistableMemory(input);
......@@ -423,27 +415,23 @@ void Executor<Device, T>::SetInput(const LoDTensor &input,
}
}
target->Resize(input.dims());
target->ShareDataWith(input);
target->set_lod(input.lod());
target.Resize(input.dims());
target.ShareDataWith(input);
target.set_lod(input.lod());
}
template <typename Device, typename T>
std::shared_ptr<LoDTensor> Executor<Device, T>::GetOutput(
const std::string &var_name) {
framework::LoDTensor *target = nullptr;
int index = 0;
if (fetch_indices_.find(var_name) != fetch_indices_.end()) {
int index = fetch_indices_.find(var_name)->second;
auto *fetch_var = program_.scope->Var("fetch");
target = &(
fetch_var->template GetMutable<framework::LoDTensorArray>()->at(index));
} else {
auto *target_var = program_.scope->FindVar(var_name);
PADDLE_MOBILE_ENFORCE(target_var != nullptr, "Variable %s is not exist",
var_name.c_str());
target = target_var->template GetMutable<LoDTensor>();
index = fetch_indices_.find(var_name)->second;
}
return std::make_shared<LoDTensor>(*target);
auto *fetch_var = program_.scope->Var("fetch");
framework::LoDTensor &target =
fetch_var->template GetMutable<framework::LoDTensorArray>()->at(index);
return std::make_shared<LoDTensor>(target);
}
template <typename Device, typename T>
......
......@@ -327,3 +327,6 @@ LOAD_OP1(roi_perspective_transform, CPU);
#ifdef BEAM_SEARCH_DECODE_OP
LOAD_OP1(beam_search_decode, CPU);
#endif
#ifdef PAD2D_OP
LOAD_OP1(pad2d, CPU);
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PAD2D_OP
#include "operators/kernel/pad2d_kernel.h"
#include "operators/math/pad.h"
namespace paddle_mobile {
namespace operators {
template <>
bool Pad2DKernel<CPU, float>::Init(Pad2DParam<CPU> *param) {
return true;
}
template <>
void Pad2DKernel<CPU, float>::Compute(const Pad2DParam<CPU> &param) {
const auto *input = param.input_;
auto *output = param.output_;
const auto &paddings = param.paddings_;
// if (param.mode_ == "constant" && param.pad_value_ == 0) {
math::PadFunctor<CPU, float> pad;
pad(*input, paddings[0], paddings[1], paddings[2], paddings[3], output);
// } else {
// PADDLE_MOBILE_THROW_EXCEPTION("Pad2D has not been implemented.");
// }
output->set_lod(input->lod());
}
} // namespace operators
} // namespace paddle_mobile
#endif // PAD2D_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PAD2D_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype>
class Pad2DParam : public OpParam {
public:
Pad2DParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_ = OpParam::GetVarValue<framework::LoDTensor>("X", inputs, scope);
output_ = OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, scope);
paddings_ = OpParam::GetAttr<std::vector<int>>("paddings", attrs);
pad_value_ = OpParam::GetAttr<float>("pad_value", attrs);
mode_ = OpParam::GetStringAttr("mode", attrs);
}
public:
framework::LoDTensor *input_;
framework::LoDTensor *output_;
std::vector<int> paddings_;
float pad_value_;
std::string mode_;
};
DECLARE_KERNEL(Pad2D, Pad2DParam);
} // namespace operators
} // namespace paddle_mobile
#endif // PAD2D_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PAD2D_OP
#include "operators/pad2d.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void Pad2DOp<Dtype, T>::InferShape() const {
auto input_dims = this->param_.input_->dims();
const auto &paddings = this->param_.paddings_;
PADDLE_MOBILE_ENFORCE(paddings.size() == 4,
"Size of paddings should be equal to 4.");
input_dims[2] += paddings[0] + paddings[1];
input_dims[3] += paddings[2] + paddings[3];
this->param_.output_->Resize(input_dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(pad2d, ops::Pad2DOp);
#endif
#endif // PAD2D_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PAD2D_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/pad2d_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
DECLARE_OPERATOR(Pad2D, Pad2DParam, Pad2DKernel);
} // namespace operators
} // namespace paddle_mobile
#endif // PAD2D_OP
......@@ -300,6 +300,7 @@ if(NOT FOUND_MATCH)
set(PSROI_POOL_OP ON)
set(ROI_PERSPECTIVE_OP ON)
set(BEAM_SEARCH_DECODE_OP ON)
set(PAD2D_OP ON)
endif()
# option(BATCHNORM_OP "" ON)
......@@ -606,3 +607,6 @@ endif()
if (BEAM_SEARCH_DECODE_OP)
add_definitions(-DBEAM_SEARCH_DECODE_OP)
endif()
if (PAD2D_OP)
add_definitions(-DPAD2D_OP)
endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册