提交 57021c48 编写于 作者: Z zp7 提交者: Yanzhan Yang

add assign,equal,fill_constant_batch_size_like,range,reduce_prod,cond… …...

add assign,equal,fill_constant_batch_size_like,range,reduce_prod,cond…  … …itional_block cpu op (#1783)

* add assign,equal,fill_constant_batch_size_like,range,reduce_prod,conditional_block cpu op

* 1.fix roi_perspective_transform,lod_reset,multiclass_nms,slice,while op
2.elementwise_add compute support int input
3.scale compute support int64_t input
上级 69bbc70e
...@@ -122,6 +122,15 @@ const char *G_OP_TYPE_PAD2D = "pad2d"; ...@@ -122,6 +122,15 @@ const char *G_OP_TYPE_PAD2D = "pad2d";
const char *G_OP_TYPE_FUSION_DECONV_ADD_BN_RELU = "fusion_deconv_add_bn_relu"; const char *G_OP_TYPE_FUSION_DECONV_ADD_BN_RELU = "fusion_deconv_add_bn_relu";
const char *G_OP_TYPE_FUSION_DECONV_ADD_BN = "fusion_deconv_add_bn"; const char *G_OP_TYPE_FUSION_DECONV_ADD_BN = "fusion_deconv_add_bn";
const char *G_OP_TYPE_FUSION_DECONV_BN_RELU = "fusion_deconv_bn_relu"; const char *G_OP_TYPE_FUSION_DECONV_BN_RELU = "fusion_deconv_bn_relu";
const char *G_OP_TYPE_ASSIGN = "assign";
const char *G_OP_TYPE_REDUCE_PROD = "reduce_prod";
const char *G_OP_TYPE_EQUAL = "equal";
const char *G_OP_TYPE_CONDITIONAL_BLOCK = "conditional_block";
const char *G_OP_TYPE_RANGE = "range";
const char *G_OP_TYPE_WHILE = "while";
const char *G_OP_TYPE_BEAM_SEARCH_DECODE = "beam_search_decode";
const char *G_OP_TYPE_FILL_CONSTAN_BATCH_SIZE_LIKE =
"fill_constant_batch_size_like";
std::unordered_map< std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>> std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
...@@ -234,5 +243,14 @@ std::unordered_map< ...@@ -234,5 +243,14 @@ std::unordered_map<
{G_OP_TYPE_FUSION_DECONV_ADD_BN_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_ADD_BN_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_ADD_BN, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_ADD_BN, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_BN_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_BN_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_REDUCE_PROD, {{"X"}, {"Out"}}},
{G_OP_TYPE_ASSIGN, {{"X"}, {"Out"}}},
{G_OP_TYPE_EQUAL, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_RANGE, {{"Start", "End", "Step"}, {"Out"}}},
{G_OP_TYPE_CONDITIONAL_BLOCK, {{"Input", "Cond"}, {"Out", "Scope"}}},
{G_OP_TYPE_WHILE, {{"Condition", "X"}, {"Out", "StepScopes"}}},
{G_OP_TYPE_BEAM_SEARCH_DECODE,
{{"Ids", "Scores"}, {"SentenceIds", "SentenceScores"}}},
{G_OP_TYPE_FILL_CONSTAN_BATCH_SIZE_LIKE, {{"Input"}, {"Out"}}},
{G_OP_TYPE_PAD2D, {{"X"}, {"Out"}}}}; {G_OP_TYPE_PAD2D, {{"X"}, {"Out"}}}};
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -353,3 +353,21 @@ LOAD_OP1(assign_value, CPU); ...@@ -353,3 +353,21 @@ LOAD_OP1(assign_value, CPU);
#ifdef EXP_OP #ifdef EXP_OP
LOAD_OP1(exp, CPU); LOAD_OP1(exp, CPU);
#endif #endif
#ifdef ASSIGN_OP
LOAD_OP1(assign, CPU);
#endif
#ifdef CONDITIONAL_BLOCK_OP
LOAD_OP1(conditional_block, CPU);
#endif
#ifdef EQUAL_OP
LOAD_OP1(equal, CPU);
#endif
#ifdef FILL_CONSTANT_BATCH_SIZE_LIKE_OP
LOAD_OP1(fill_constant_batch_size_like, CPU);
#endif
#ifdef RANGE_OP
LOAD_OP1(range, CPU);
#endif
#ifdef REDUCE_PROD_OP
LOAD_OP1(reduce_prod, CPU);
#endif
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "framework/operator.h"
#include "framework/program/program_desc.h" #include "framework/program/program_desc.h"
#include "framework/program/tensor_desc.h" #include "framework/program/tensor_desc.h"
...@@ -51,6 +52,17 @@ void ProgramDesc::Description(std::string header) const { ...@@ -51,6 +52,17 @@ void ProgramDesc::Description(std::string header) const {
if (header.size()) { if (header.size()) {
LOG(kLOG_INFO) << header; LOG(kLOG_INFO) << header;
} }
for (int i = 0; i < this->blocks_.size(); ++i) {
auto block = this->blocks_[i];
for (int j = 0; j < block->Ops().size(); ++j) {
std::shared_ptr<OpDesc> op_desc = block->Ops()[j];
auto op_info_ptr =
OpInfoMap<CPU>::Instance()->GetNullable(op_desc->Type());
if (op_info_ptr == nullptr) {
DLOG << "Operator has not been registered :" << op_desc->Type().c_str();
}
}
}
for (int i = 0; i < this->blocks_.size(); ++i) { for (int i = 0; i < this->blocks_.size(); ++i) {
auto block = this->blocks_[i]; auto block = this->blocks_[i];
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ASSIGN_OP
#include "operators/assign_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void AssignOp<Dtype, T>::InferShape() const {
PADDLE_MOBILE_ENFORCE(this->param_.Input() != nullptr,
"Input (X) of Assign op should not be null.");
PADDLE_MOBILE_ENFORCE(this->param_.Output() != nullptr,
"Output (Output) of Assign op should not be null.");
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(assign, ops::AssignOp);
#endif
#endif // ASSIGN_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ASSIGN_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/assign_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
DECLARE_OPERATOR(Assign, AssignParam, AssignKernel);
} // namespace operators
} // namespace paddle_mobile
#endif
...@@ -25,6 +25,14 @@ void LessThanOp<Dtype, T>::InferShape() const { ...@@ -25,6 +25,14 @@ void LessThanOp<Dtype, T>::InferShape() const {
} }
#endif // LESS_THAN_OP #endif // LESS_THAN_OP
#ifdef EQUAL_OP
template <typename Dtype, typename T>
void EqualOp<Dtype, T>::InferShape() const {
const auto &input_dims = this->param_.input_x_->dims();
this->param_.output_->Resize(input_dims);
}
#endif // EQUAL_OP
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -32,3 +40,6 @@ namespace ops = paddle_mobile::operators; ...@@ -32,3 +40,6 @@ namespace ops = paddle_mobile::operators;
#ifdef LESS_THAN_OP #ifdef LESS_THAN_OP
REGISTER_OPERATOR_CPU(less_than, ops::LessThanOp); REGISTER_OPERATOR_CPU(less_than, ops::LessThanOp);
#endif // LESS_THAN_OP #endif // LESS_THAN_OP
#ifdef EQUAL_OP
REGISTER_OPERATOR_CPU(equal, ops::EqualOp);
#endif // EQUAL_OP
...@@ -26,5 +26,9 @@ namespace operators { ...@@ -26,5 +26,9 @@ namespace operators {
DECLARE_OPERATOR(LessThan, CompareParam, LessThanKernel); DECLARE_OPERATOR(LessThan, CompareParam, LessThanKernel);
#endif // LESS_THAN_OP #endif // LESS_THAN_OP
#ifdef EQUAL_OP
DECLARE_OPERATOR(Equal, CompareParam, EqualKernel);
#endif // EQUAL_OP
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONDITIONAL_BLOCK_OP
#include "operators/conditional_block_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void ConditionalBlockOp<Dtype, T>::InferShape() const {}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(conditional_block, ops::ConditionalBlockOp);
#endif
#endif // CONDITIONAL_BLOCK_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONDITIONAL_BLOCK_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/conditional_block_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
DECLARE_OPERATOR(ConditionalBlock, ConditionalBlockParam,
ConditionalBlockKernel);
} // namespace operators
} // namespace paddle_mobile
#endif
...@@ -95,6 +95,17 @@ void RoiPerspectiveOp<DeviceType, T>::InferShape() const { ...@@ -95,6 +95,17 @@ void RoiPerspectiveOp<DeviceType, T>::InferShape() const {
static_cast<int64_t>(transformed_width)}); static_cast<int64_t>(transformed_width)});
auto out_dims = framework::make_ddim(out_dims_v); auto out_dims = framework::make_ddim(out_dims_v);
this->param_.output_->Resize(out_dims); this->param_.output_->Resize(out_dims);
std::vector<int64_t> mask_dims_v({rois_dims[0], // num_rois
1, // channels
static_cast<int64_t>(transformed_height),
static_cast<int64_t>(transformed_width)});
auto mask_dims = framework::make_ddim(mask_dims_v);
std::vector<int64_t> matrix_dims_v({rois_dims[0], 9});
auto matrix_dims = framework::make_ddim(matrix_dims_v);
this->param_.transform_Matrix_->Resize(matrix_dims);
this->param_.mask->Resize(mask_dims);
} }
#endif #endif
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FILL_CONSTANT_BATCH_SIZE_LIKE_OP
#include "operators/fill_constant_batch_size_like_op.h"
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FILL_CONSTANT_BATCH_SIZE_LIKE_OP
#pragma once
#include <algorithm>
#include <string>
#include "framework/data_type.h"
#include "framework/operator.h"
#include "framework/selected_rows.h"
#include "operators/math/math_function.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class FillConstantBatchSizeLikeOp : public framework::OperatorBase<DeviceType> {
public:
FillConstantBatchSizeLikeOp(const std::string &type,
const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap attrs,
framework::Scope *scope)
: framework::OperatorBase<DeviceType>(type, inputs, outputs, attrs,
scope),
param_(inputs, outputs, attrs, scope) {}
void RunImpl() {
auto data_type =
static_cast<_PaddleMobile__Framework__Proto__VarType__Type>(
param_.DataDtype());
framework::Tensor *tensor = nullptr;
auto value = param_.Value();
auto *outvar = param_.OutVar();
if (outvar->template IsType<framework::LoDTensor>()) {
tensor = outvar->template GetMutable<framework::LoDTensor>();
} else if (outvar->template IsType<framework::SelectedRows>()) {
tensor = outvar->template GetMutable<framework::SelectedRows>()
->mutable_value();
} else {
PADDLE_MOBILE_THROW_EXCEPTION(
"fill constant batch size like op's output only"
"supports SelectedRows and LoDTensor");
}
auto shape = param_.Shape();
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto ddim = framework::make_ddim(shape_int64);
ddim[param_.OutputDimIdx()] = param_.Input()->dims()[param_.InputDimIdx()];
tensor->Resize(ddim);
tensor->mutable_data(framework::ToTypeIndex(data_type));
math::SetConstant(tensor, value);
}
void Init() {}
void InferShape() const {
PADDLE_MOBILE_ENFORCE(
param_.Out() != nullptr,
"Output (Out) of fill_constant_batch_size_like op should not be null.");
auto shape = param_.Shape();
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
DLOG << shape_int64;
auto ddim = framework::make_ddim(shape_int64);
ddim[param_.OutputDimIdx()] = param_.Input()->dims()[param_.InputDimIdx()];
param_.Out()->Resize(ddim);
}
protected:
FillConstantBatchSizeLikeParam<DeviceType> param_;
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ASSIGN_OP
#include "operators/kernel/assign_kernel.h"
#include "framework/data_type.h"
namespace paddle_mobile {
namespace operators {
template <>
bool AssignKernel<CPU, float>::Init(AssignParam<CPU>* param) {
return true;
}
template <>
void AssignKernel<CPU, float>::Compute(const AssignParam<CPU>& param) {
const auto* input = param.Input();
auto* out = param.Output();
out->mutable_data<float>();
framework::TensorCopy(*input, out);
}
} // namespace operators
} // namespace paddle_mobile
#endif // ASSIGN_OP
...@@ -41,6 +41,11 @@ inline uint8_t Compare(const float x, const float y) { ...@@ -41,6 +41,11 @@ inline uint8_t Compare(const float x, const float y) {
return static_cast<uint8_t>(x < y); return static_cast<uint8_t>(x < y);
} }
template <CompareType Comp = EQUAL>
inline uint8_t Compare(const int x, const int y) {
return static_cast<uint8_t>(x == y);
}
template <CompareType Comp = LESS_THAN> template <CompareType Comp = LESS_THAN>
inline uint8_t Compare(const int64_t x, const int64_t y) { inline uint8_t Compare(const int64_t x, const int64_t y) {
return static_cast<uint8_t>(x < y); return static_cast<uint8_t>(x < y);
...@@ -184,6 +189,51 @@ struct CompareCompute<int64_t, Comp> { ...@@ -184,6 +189,51 @@ struct CompareCompute<int64_t, Comp> {
} }
}; };
template <CompareType Comp>
struct CompareCompute<int, Comp> {
void operator()(const Tensor *X, const Tensor *Y, const int Axis,
Tensor *Out) {
const int *x = X->data<int>();
const int *y = Y->data<int>();
uint8_t *output = reinterpret_cast<uint8_t *>(Out->mutable_data<bool>());
const auto &x_dims = X->dims();
const auto &y_dims = Y->dims();
/// axis = -1 represent the last dimensions.
int axis = (Axis == -1 ? x_dims.size() - y_dims.size() : Axis);
int batch = 1;
int channels = 1;
int elementwise_num = 1;
for (int i = 0; i < axis; ++i) {
batch *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
elementwise_num *= x_dims[i];
}
// if elementwise_num == 1, compare rowwise
if (elementwise_num == 1) {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int x_offset = i * channels + j;
output[x_offset] = Compare<Comp>(x[x_offset], y[j]);
}
}
} else {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int x_offset = (i * channels + j) * elementwise_num;
int y_offset = j * elementwise_num;
for (int k = 0; k < elementwise_num; ++k) {
output[x_offset + k] = Compare<Comp>(x[x_offset + k], y[y_offset]);
}
}
}
}
}
};
#ifdef LESS_THAN_OP #ifdef LESS_THAN_OP
template <> template <>
bool LessThanKernel<CPU, float>::Init(CompareParam<CPU> *param) { bool LessThanKernel<CPU, float>::Init(CompareParam<CPU> *param) {
...@@ -205,5 +255,20 @@ void LessThanKernel<CPU, float>::Compute(const CompareParam<CPU> &param) { ...@@ -205,5 +255,20 @@ void LessThanKernel<CPU, float>::Compute(const CompareParam<CPU> &param) {
} }
#endif // LESS_THAN_OP #endif // LESS_THAN_OP
#ifdef EQUAL_OP
template <>
bool EqualKernel<CPU, float>::Init(CompareParam<CPU> *param) {
return true;
}
template <>
void EqualKernel<CPU, float>::Compute(const CompareParam<CPU> &param) {
if (param.input_x_->type() == type_id<int>().hash_code()) {
CompareCompute<int, EQUAL>()(param.input_x_, param.input_y_, param.axis_,
param.output_);
}
}
#endif // EQUAL_OP
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONDITIONAL_BLOCK_OP
#include "operators/kernel/conditional_block_kernel.h"
#include <framework/program/block_desc.h>
#include <framework/program/op_desc.h>
#include "framework/data_type.h"
namespace paddle_mobile {
namespace operators {
class StepExecutor {
typedef std::shared_ptr<framework::OperatorBase<CPU>> OperatorPtr;
public:
StepExecutor(const framework::BlockDesc *block, framework::Scope *scope)
: scope_(scope) {
std::vector<std::shared_ptr<framework::OpDesc>> ops = block->Ops();
ops_of_block_.resize(ops.size());
for (int i = 0; i < ops.size(); ++i) {
std::shared_ptr<framework::OpDesc> op_desc = ops[i];
DLOG << "conditional block create op: " << ops.size() << ","
<< op_desc->Type();
auto op_handler = framework::OpRegistry<CPU>::CreateOp(
op_desc->Type(), op_desc->GetInputs(), op_desc->GetOutputs(),
op_desc->GetAttrMap(), scope_);
op_handler->Init();
ops_of_block_[i] = op_handler;
}
}
void Run() {
for (int i = 0; i < ops_of_block_.size(); ++i) {
auto &op_handler = ops_of_block_[i];
DLOG << "conditional block op InferShape: " << i
<< "th: " << op_handler->Type();
op_handler->InferShape();
DLOG << "conditional block op Run: " << i << "th: " << op_handler->Type();
op_handler->Run();
}
}
private:
framework::Scope *scope_;
std::vector<OperatorPtr> ops_of_block_;
};
template <>
bool ConditionalBlockKernel<CPU, float>::Init(
ConditionalBlockParam<CPU> *param) {
return true;
}
template <>
void ConditionalBlockKernel<CPU, float>::Compute(
const ConditionalBlockParam<CPU> &param) {
bool need_run;
if (param.isScalarCondition()) {
auto xs = param.Cond();
PADDLE_MOBILE_ENFORCE(
xs[0]->type() == type_id<bool>().hash_code() && xs[0]->numel() == 1,
"condition input's data type should be bool, "
"numel should be 1, actual numel is %d",
xs[0]->numel());
need_run = xs[0]->data<bool>()[0];
} else {
auto xs = param.Input();
need_run = std::all_of(
xs.begin(), xs.end(),
[](const framework::LoDTensor *t) { return t->numel() != 0; });
}
if (need_run) {
auto input = param.Input();
auto sub = param.getSubBlock();
auto &current_scope = param.GetScope()->NewScope();
StepExecutor executor(sub, &current_scope);
executor.Run();
param.GetScope()->DeleteScope(&current_scope);
}
}
} // namespace operators
} // namespace paddle_mobile
#endif // CONDITIONAL_BLOCK_OP
...@@ -28,7 +28,12 @@ bool ElementwiseAddKernel<CPU, float>::Init(ElementwiseAddParam<CPU> *param) { ...@@ -28,7 +28,12 @@ bool ElementwiseAddKernel<CPU, float>::Init(ElementwiseAddParam<CPU> *param) {
template <> template <>
void ElementwiseAddKernel<CPU, float>::Compute( void ElementwiseAddKernel<CPU, float>::Compute(
const ElementwiseAddParam<CPU> &param) { const ElementwiseAddParam<CPU> &param) {
ElementwiseAddCompute<float>(param); if (param.InputX()->type() == type_id<float>().hash_code()) {
ElementwiseAddCompute<float>(param);
} else if (param.InputX()->type() == type_id<int>().hash_code()) {
AddElememtWiseStruct<int, IDENTITY>()(param.InputX(), param.InputY(),
param.Axis(), param.Out());
}
param.Out()->set_lod(param.InputX()->lod()); param.Out()->set_lod(param.InputX()->lod());
} }
......
...@@ -28,6 +28,7 @@ template <> ...@@ -28,6 +28,7 @@ template <>
void LodResetKernel<CPU, float>::Compute(const LodResetParam<CPU> &param) { void LodResetKernel<CPU, float>::Compute(const LodResetParam<CPU> &param) {
const auto *input = param.input_x_; const auto *input = param.input_x_;
const auto *lod_t = param.input_y_; const auto *lod_t = param.input_y_;
bool append = param.append;
auto *output = param.output_; auto *output = param.output_;
output->ShareDataWith(*input); output->ShareDataWith(*input);
...@@ -47,13 +48,17 @@ void LodResetKernel<CPU, float>::Compute(const LodResetParam<CPU> &param) { ...@@ -47,13 +48,17 @@ void LodResetKernel<CPU, float>::Compute(const LodResetParam<CPU> &param) {
// cast level0 to size_t // cast level0 to size_t
std::vector<size_t> ulevel0(level0.size(), 0); std::vector<size_t> ulevel0(level0.size(), 0);
for (int i = 0; i < level0.size(); ++i) { std::transform(level0.begin(), level0.end(), ulevel0.begin(),
ulevel0[i] = level0[i]; [](int a) { return static_cast<size_t>(a); });
}
framework::LoD target_lod; if (append) {
target_lod.push_back(std::move(ulevel0)); auto *out_lod = output->mutable_lod();
output->set_lod(target_lod); out_lod->push_back(ulevel0);
} else {
framework::LoD target_lod;
target_lod.push_back(ulevel0);
output->set_lod(target_lod);
}
} }
} // namespace operators } // namespace operators
......
...@@ -208,6 +208,8 @@ void RoiPerspectiveKernel<CPU, float>::Compute( ...@@ -208,6 +208,8 @@ void RoiPerspectiveKernel<CPU, float>::Compute(
const auto *input_x = param.input_x_; const auto *input_x = param.input_x_;
const auto *input_rois = param.input_rois_; const auto *input_rois = param.input_rois_;
auto *output = param.output_; auto *output = param.output_;
auto *transform_Matrix = param.transform_Matrix_;
auto *mask = param.mask;
const auto &in_dims = input_x->dims(); const auto &in_dims = input_x->dims();
const int channels = in_dims[1]; const int channels = in_dims[1];
...@@ -221,6 +223,9 @@ void RoiPerspectiveKernel<CPU, float>::Compute( ...@@ -221,6 +223,9 @@ void RoiPerspectiveKernel<CPU, float>::Compute(
const float *input_data = input_x->data<float>(); const float *input_data = input_x->data<float>();
const float *rois_data = input_rois->data<float>(); const float *rois_data = input_rois->data<float>();
float *output_data = output->mutable_data<float>(); float *output_data = output->mutable_data<float>();
int *mask_data = mask->mutable_data<int>();
float *transform_matrix =
transform_Matrix->mutable_data<float>({rois_num, 9});
std::vector<int> roi2image(rois_num); std::vector<int> roi2image(rois_num);
const auto &lod = input_rois->lod().back(); const auto &lod = input_rois->lod().back();
...@@ -240,9 +245,13 @@ void RoiPerspectiveKernel<CPU, float>::Compute( ...@@ -240,9 +245,13 @@ void RoiPerspectiveKernel<CPU, float>::Compute(
} }
int image_id = roi2image[n]; int image_id = roi2image[n];
// Get transform matrix // Get transform matrix
float transform_matrix[9]; // float transform_matrix[9];
float matrix[9];
get_transform_matrix<float>(transformed_width, transformed_height, roi_x, get_transform_matrix<float>(transformed_width, transformed_height, roi_x,
roi_y, transform_matrix); roi_y, matrix);
for (int i = 0; i < 9; i++) {
transform_matrix[n * 9 + i] = matrix[i];
}
for (int c = 0; c < channels; ++c) { for (int c = 0; c < channels; ++c) {
for (int out_h = 0; out_h < transformed_height; ++out_h) { for (int out_h = 0; out_h < transformed_height; ++out_h) {
for (int out_w = 0; out_w < transformed_width; ++out_w) { for (int out_w = 0; out_w < transformed_width; ++out_w) {
...@@ -251,19 +260,24 @@ void RoiPerspectiveKernel<CPU, float>::Compute( ...@@ -251,19 +260,24 @@ void RoiPerspectiveKernel<CPU, float>::Compute(
c * transformed_height * transformed_width + c * transformed_height * transformed_width +
out_h * transformed_width + out_w; out_h * transformed_width + out_w;
float in_w, in_h; float in_w, in_h;
get_source_coords<float>(transform_matrix, out_w, out_h, &in_w, get_source_coords<float>(matrix, out_w, out_h, &in_w, &in_h);
&in_h);
if (in_quad<float>(in_w, in_h, roi_x, roi_y)) { if (in_quad<float>(in_w, in_h, roi_x, roi_y)) {
if ((-0.5 > in_w) || (in_w > (in_width - 0.5)) || (-0.5 > in_h) || if ((-0.5 > in_w) || (in_w > (in_width - 0.5)) || (-0.5 > in_h) ||
(in_h > (in_height - 0.5))) { (in_h > (in_height - 0.5))) {
output_data[out_index] = 0.0; output_data[out_index] = 0.0;
mask_data[(n * transformed_height + out_h) * transformed_width +
out_w] = 0;
} else { } else {
bilinear_interpolate<float>(input_data, channels, in_width, bilinear_interpolate<float>(input_data, channels, in_width,
in_height, image_id, c, in_w, in_h, in_height, image_id, c, in_w, in_h,
output_data + out_index); output_data + out_index);
mask_data[(n * transformed_height + out_h) * transformed_width +
out_w] = 1;
} }
} else { } else {
output_data[out_index] = 0.0; output_data[out_index] = 0.0;
mask_data[(n * transformed_height + out_h) * transformed_width +
out_w] = 1;
} }
} }
} }
......
...@@ -33,39 +33,49 @@ void ScaleKernel<CPU, float>::Compute(const ScaleParam<CPU> &param) { ...@@ -33,39 +33,49 @@ void ScaleKernel<CPU, float>::Compute(const ScaleParam<CPU> &param) {
auto output = param.Out(); auto output = param.Out();
const float scale = param.Scale(); const float scale = param.Scale();
const float bias = param.Bias(); const float bias = param.Bias();
const float *input_data = input->data<float>(); if (input->type() == type_id<int64_t>().hash_code()) {
float *output_data = output->mutable_data<float>(); const int64_t *input_data = input->data<int64_t>();
int64_t *output_data = output->mutable_data<int64_t>();
int i = 0; int i = 0;
for (; i < output->numel(); ++i, ++output_data, ++input_data) {
*output_data = scale * (*input_data) + bias;
}
} else {
const float *input_data = input->data<float>();
float *output_data = output->mutable_data<float>();
int i = 0;
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
float32x4_t vscale = vdupq_n_f32(scale); float32x4_t vscale = vdupq_n_f32(scale);
float32x4_t vbias = vdupq_n_f32(bias); float32x4_t vbias = vdupq_n_f32(bias);
for (; i < output->numel() - 15; i += 16) { for (; i < output->numel() - 15; i += 16) {
float32x4_t _in0 = vld1q_f32(input_data); float32x4_t _in0 = vld1q_f32(input_data);
float32x4_t _in1 = vld1q_f32(input_data + 4); float32x4_t _in1 = vld1q_f32(input_data + 4);
float32x4_t _in2 = vld1q_f32(input_data + 8); float32x4_t _in2 = vld1q_f32(input_data + 8);
float32x4_t _in3 = vld1q_f32(input_data + 12); float32x4_t _in3 = vld1q_f32(input_data + 12);
_in0 = vmlaq_f32(vbias, vscale, _in0); _in0 = vmlaq_f32(vbias, vscale, _in0);
_in1 = vmlaq_f32(vbias, vscale, _in1); _in1 = vmlaq_f32(vbias, vscale, _in1);
_in2 = vmlaq_f32(vbias, vscale, _in2); _in2 = vmlaq_f32(vbias, vscale, _in2);
_in3 = vmlaq_f32(vbias, vscale, _in3); _in3 = vmlaq_f32(vbias, vscale, _in3);
vst1q_f32(output_data, _in0); vst1q_f32(output_data, _in0);
vst1q_f32(output_data + 4, _in1); vst1q_f32(output_data + 4, _in1);
vst1q_f32(output_data + 8, _in2); vst1q_f32(output_data + 8, _in2);
vst1q_f32(output_data + 12, _in3); vst1q_f32(output_data + 12, _in3);
input_data += 16; input_data += 16;
output_data += 16; output_data += 16;
} }
for (; i < output->numel() - 3; i += 4) { for (; i < output->numel() - 3; i += 4) {
float32x4_t _in0 = vld1q_f32(input_data); float32x4_t _in0 = vld1q_f32(input_data);
_in0 = vmlaq_f32(vbias, vscale, _in0); _in0 = vmlaq_f32(vbias, vscale, _in0);
vst1q_f32(output_data, _in0); vst1q_f32(output_data, _in0);
input_data += 4; input_data += 4;
output_data += 4; output_data += 4;
} }
#endif #endif
for (; i < output->numel(); ++i, ++output_data, ++input_data) { for (; i < output->numel(); ++i, ++output_data, ++input_data) {
*output_data = scale * (*input_data) + bias; *output_data = scale * (*input_data) + bias;
}
} }
} }
......
...@@ -19,18 +19,22 @@ limitations under the License. */ ...@@ -19,18 +19,22 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
template <typename Dtype>
void SliceCompute(const SliceParam<CPU>& param) { void SliceCompute(const SliceParam<CPU>& param) {
auto input = param.input_; auto input = param.input_;
auto output = param.output_; auto output = param.output_;
auto* input_ptr = input->data<float>(); auto* input_ptr = input->data<Dtype>();
auto* output_ptr = output->mutable_data<float>(); auto* output_ptr = output->mutable_data<Dtype>();
auto out_dims = output->dims(); auto out_dims = output->dims();
auto in_dims = input->dims(); auto in_dims = input->dims();
auto starts = param.starts_; auto starts = param.starts_;
auto ends = param.ends_; auto ends = param.ends_;
int axes = param.axes_[0]; int axes = param.axes_[0];
int HW = input->dims()[axes + 1] * input->dims()[axes + 2]; int HW = 1;
int batch_size = out_dims[axes - 1]; if (in_dims.size() >= 2 && axes <= in_dims.size() - 2) {
HW = in_dims[axes + 1] * input->dims()[axes + 2];
}
int batch_size = (out_dims.size() == 1) ? 1 : out_dims[axes - 1];
int input_channel = in_dims[axes]; int input_channel = in_dims[axes];
int output_channel = out_dims[axes]; int output_channel = out_dims[axes];
...@@ -53,12 +57,22 @@ template <> ...@@ -53,12 +57,22 @@ template <>
void SliceKernel<CPU, float>::Compute(const SliceParam<CPU>& param) { void SliceKernel<CPU, float>::Compute(const SliceParam<CPU>& param) {
int rank = param.input_->dims().size(); int rank = param.input_->dims().size();
switch (rank) { switch (rank) {
case 1:
if (param.input_->type() == type_id<int>().hash_code()) {
SliceCompute<int>(param);
} else if (param.input_->type() == type_id<float>().hash_code()) {
SliceCompute<float>(param);
}
break;
case 2:
SliceCompute<float>(param);
break;
case 4: case 4:
SliceCompute(param); SliceCompute<float>(param);
break; break;
case 5: case 5:
if (param.input_->dims()[0] == 1) { if (param.input_->dims()[0] == 1) {
SliceCompute(param); SliceCompute<float>(param);
} }
break; break;
default: default:
......
...@@ -15,23 +15,25 @@ limitations under the License. */ ...@@ -15,23 +15,25 @@ limitations under the License. */
#ifdef WHILE_OP #ifdef WHILE_OP
#include "operators/kernel/while_kernel.h" #include "operators/kernel/while_kernel.h"
#include "framework/loader.h"
#include "framework/lod_tensor.h"
#include "framework/op_registry.h" #include "framework/op_registry.h"
#include "framework/operator.h" #include "framework/operator.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
class StepExecutor { class WhileStepExecutor {
typedef std::shared_ptr<framework::OperatorBase<CPU>> OperatorPtr; typedef std::shared_ptr<framework::OperatorBase<CPU>> OperatorPtr;
public: public:
StepExecutor(const framework::BlockDesc *block, framework::Scope *scope) WhileStepExecutor(const framework::BlockDesc *block, framework::Scope *scope)
: scope_(scope) { : scope_(scope) {
std::vector<std::shared_ptr<framework::OpDesc>> ops = block->Ops(); std::vector<std::shared_ptr<framework::OpDesc>> ops = block->Ops();
ops_of_block_.resize(ops.size()); ops_of_block_.resize(ops.size());
for (int i = 0; i < ops.size(); ++i) { for (int i = 0; i < ops.size(); ++i) {
std::shared_ptr<framework::OpDesc> op_desc = ops[i]; std::shared_ptr<framework::OpDesc> op_desc = ops[i];
DLOG << "create op: " << op_desc->Type(); DLOG << "while kernel create op: " << op_desc->Type();
auto op_handler = framework::OpRegistry<CPU>::CreateOp( auto op_handler = framework::OpRegistry<CPU>::CreateOp(
op_desc->Type(), op_desc->GetInputs(), op_desc->GetOutputs(), op_desc->Type(), op_desc->GetInputs(), op_desc->GetOutputs(),
op_desc->GetAttrMap(), scope_); op_desc->GetAttrMap(), scope_);
...@@ -41,12 +43,46 @@ class StepExecutor { ...@@ -41,12 +43,46 @@ class StepExecutor {
} }
void Run() { void Run() {
for (auto &op_handler : ops_of_block_) { for (int i = 0; i < ops_of_block_.size(); ++i) {
auto &op_handler = ops_of_block_[i];
DLOG << "while kernel InferShape op: " << i
<< "th : " << op_handler->Type();
op_handler->InferShape(); op_handler->InferShape();
DLOG << "while kernel Run op: " << i << "th : " << op_handler->Type();
op_handler->Run(); op_handler->Run();
} }
} }
void CreateVariables(Scope &scope, const WhileParam<CPU> &param) {
for (const auto &var_desc : param.sub_block_->Vars()) {
auto var = scope.Var(var_desc->Name());
if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) {
if (var_desc->Persistable()) {
auto dim = var_desc->Tensor_desc().Dims();
auto tensor = var->framework::Variable::GetMutable<LoDTensor>();
tensor->Resize(framework::make_ddim(dim));
} else {
auto dim = var_desc->Tensor_desc().Dims();
if (dim.size() == 0) {
auto tensor = var->framework::Variable::GetMutable<LoDTensor>();
framework::DDim dDim = {0};
tensor->Resize(dDim);
} else {
for (auto &d : dim) {
if (d < 0) {
d *= -1;
}
}
auto tensor = var->framework::Variable::GetMutable<LoDTensor>();
tensor->Resize(framework::make_ddim(dim));
}
}
} else {
// TODO(codeWorm)
}
}
}
private: private:
framework::Scope *scope_; framework::Scope *scope_;
std::vector<OperatorPtr> ops_of_block_; std::vector<OperatorPtr> ops_of_block_;
...@@ -59,9 +95,28 @@ bool WhileKernel<CPU, float>::Init(WhileParam<CPU> *param) { ...@@ -59,9 +95,28 @@ bool WhileKernel<CPU, float>::Init(WhileParam<CPU> *param) {
template <> template <>
void WhileKernel<CPU, float>::Compute(const WhileParam<CPU> &param) { void WhileKernel<CPU, float>::Compute(const WhileParam<CPU> &param) {
DLOG << "WhileKernel Compute";
WhileStepExecutor executor(param.sub_block_, param.scope_);
auto &current_scope = param.scope_->NewScope(); auto &current_scope = param.scope_->NewScope();
StepExecutor executor(param.sub_block_, &current_scope); executor.CreateVariables(current_scope, param);
while (param.cond_->data<bool>()[0]) { while (param.cond_->data<bool>()[0]) {
if (param.is_test) {
for (auto &name : current_scope.LocalVarNames()) {
auto *var = current_scope.Var(name);
if (var->IsType<framework::LoDTensor>()) {
// Clear all lod information for all lod_tensors.
auto *t = var->GetMutable<framework::LoDTensor>();
framework::LoD empty_lod;
t->set_lod(empty_lod);
} else if (var->IsType<framework::LoDTensorArray>()) {
// Clear elements of all tensor arrays.
auto *t = var->GetMutable<framework::LoDTensorArray>();
t->clear();
} else {
// todo
}
}
}
executor.Run(); executor.Run();
} }
param.scope_->DeleteScope(&current_scope); param.scope_->DeleteScope(&current_scope);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ASSIGN_OP
#pragma once
#include <vector>
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype>
class AssignParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
AssignParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
input_ = OpParam::InputXFrom<GType>(inputs, *scope);
output_ = OpParam::OutFrom<GType>(outputs, *scope);
}
const GType *Input() const { return input_; }
GType *Output() const { return output_; }
private:
GType *input_;
GType *output_;
};
DECLARE_KERNEL(Assign, AssignParam);
} // namespace operators
} // namespace paddle_mobile
#endif // ASSIGN_OP
...@@ -34,6 +34,42 @@ inline void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) { ...@@ -34,6 +34,42 @@ inline void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
math::AddElememtWise<IDENTITY>(input_x, input_y, axis, output); math::AddElememtWise<IDENTITY>(input_x, input_y, axis, output);
} }
template <typename Dtype, ActivationType Act>
struct AddElememtWiseStruct {
void operator()(const Tensor *X, const Tensor *Y, const int Axis,
Tensor *Out) {}
};
template <ActivationType Act>
struct AddElememtWiseStruct<int, Act> {
void operator()(const Tensor *input, const Tensor *bias, const int Axis,
Tensor *output) {
const auto &x_dims = input->dims();
const auto &y_dims = bias->dims();
const int *input_data = input->data<int>();
const int *bias_data = bias->data<int>();
int *output_data = output->mutable_data<int>();
if (x_dims == y_dims) {
size_t channels = 1;
size_t elementwise_num = 1;
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
#pragma omp parallel for
for (int j = 0; j < channels; ++j) {
size_t offset = (0 * channels + j) * elementwise_num;
const int *input = input_data + offset;
const int bias = bias_data[j];
int *output = output_data + offset;
for (int k = 0; k < elementwise_num; ++k) {
output[k] = math::Active<Act>(input[k] + bias);
}
}
}
}
};
template class ElementwiseAddKernel<CPU, float>; template class ElementwiseAddKernel<CPU, float>;
} // namespace operators } // namespace operators
......
...@@ -294,6 +294,11 @@ void MultiClassNMSCompute(const MultiClassNMSParam<CPU>& param) { ...@@ -294,6 +294,11 @@ void MultiClassNMSCompute(const MultiClassNMSParam<CPU>& param) {
} }
} }
} }
framework::LoD lod;
lod.emplace_back(batch_starts);
outs->set_lod(lod);
} }
} // namespace operators } // namespace operators
......
...@@ -24,5 +24,9 @@ namespace operators { ...@@ -24,5 +24,9 @@ namespace operators {
DECLARE_KERNEL(LessThan, CompareParam); DECLARE_KERNEL(LessThan, CompareParam);
#endif // LESS_THAN_OP #endif // LESS_THAN_OP
#ifdef EQUAL_OP
DECLARE_KERNEL(Equal, CompareParam);
#endif // EQUAL_OP
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONDITIONAL_BLOCK_OP
#pragma once
#include <vector>
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype>
class ConditionalBlockParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
ConditionalBlockParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
input_ = OpParam::GetMultiVarValue<GType>("Input", inputs, *scope);
cond_ = OpParam::GetMultiVarValue<GType>("Cond", inputs, *scope);
output_ = OpParam::OutFrom<GType>(outputs, *scope);
scope_ = OpParam::GetVar("Scope", outputs, *scope);
is_scalar_condition_ = GetAttr<bool>("is_scalar_condition", attrs);
sub_block_ = GetAttr<framework::BlockDesc *>("sub_block", attrs);
}
const vector<GType *> Input() const { return input_; }
const vector<GType *> Cond() const { return cond_; }
GType *Output() const { return output_; }
Variable *OutputScope() const { return scope_; }
bool isScalarCondition() const { return is_scalar_condition_; }
framework::BlockDesc *getSubBlock() const { return sub_block_; }
private:
vector<GType *> input_;
vector<GType *> cond_;
GType *output_;
Variable *scope_;
bool is_scalar_condition_;
framework::BlockDesc *sub_block_;
};
DECLARE_KERNEL(ConditionalBlock, ConditionalBlockParam);
} // namespace operators
} // namespace paddle_mobile
#endif // CONDITIONAL_BLOCK_OP
...@@ -204,6 +204,9 @@ class RoiPerspectiveParam : public OpParam { ...@@ -204,6 +204,9 @@ class RoiPerspectiveParam : public OpParam {
OpParam::GetVarValue<framework::LoDTensor>("ROIs", inputs, *scope); OpParam::GetVarValue<framework::LoDTensor>("ROIs", inputs, *scope);
output_ = output_ =
OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, *scope); OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, *scope);
transform_Matrix_ = OpParam::GetVarValue<framework::LoDTensor>(
"TransformMatrix", outputs, *scope);
mask = OpParam::GetVarValue<framework::LoDTensor>("Mask", outputs, *scope);
spatial_scale_ = OpParam::GetAttr<float>("spatial_scale", attrs); spatial_scale_ = OpParam::GetAttr<float>("spatial_scale", attrs);
transformed_height_ = OpParam::GetAttr<int>("transformed_height", attrs); transformed_height_ = OpParam::GetAttr<int>("transformed_height", attrs);
...@@ -214,6 +217,8 @@ class RoiPerspectiveParam : public OpParam { ...@@ -214,6 +217,8 @@ class RoiPerspectiveParam : public OpParam {
framework::Tensor *input_x_; framework::Tensor *input_x_;
framework::LoDTensor *input_rois_; framework::LoDTensor *input_rois_;
framework::Tensor *output_; framework::Tensor *output_;
framework::Tensor *transform_Matrix_;
framework::Tensor *mask;
float spatial_scale_; float spatial_scale_;
int transformed_height_; int transformed_height_;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RANGE_OP
#include "operators/kernel/range_kernel.h"
#include "framework/data_type.h"
namespace paddle_mobile {
namespace operators {
template <>
bool RangeKernel<CPU, float>::Init(RangeParam<CPU>* param) {
return true;
}
template <>
void RangeKernel<CPU, float>::Compute(const RangeParam<CPU>& param) {
int start = param.Start()->data<int>()[0];
int end = param.End()->data<int>()[0];
int step = param.Step()->data<int>()[0];
auto* out = param.Output();
int64_t size = 0;
GetSize(start, end, step, &size);
out->Resize(framework::make_ddim({size}));
auto* out_data = out->mutable_data<int>();
auto value = start;
for (int64_t i = 0; i < size; ++i) {
out_data[i] = value;
value += step;
}
}
} // namespace operators
} // namespace paddle_mobile
#endif // RANGE_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RANGE_OP
#pragma once
#include <vector>
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
inline void GetSize(float start, float end, float step, int64_t *size) {
PADDLE_MOBILE_ENFORCE(!std::equal_to<float>()(step, 0),
"The step of range op should not be 0.");
PADDLE_MOBILE_ENFORCE(
((start < end) && (step > 0)) || ((start > end) && (step < 0)),
"The step should be greater than 0 while start < end. And the "
"step should be less than 0 while start > end.");
*size = std::is_integral<float>::value
? ((std::abs(end - start) + std::abs(step) - 1) / std::abs(step))
: std::ceil(std::abs((end - start) / step));
}
template <typename Dtype>
class RangeParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
RangeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
start_ = OpParam::GetVarValue<GType>("Start", inputs, *scope);
end_ = OpParam::GetVarValue<GType>("End", inputs, *scope);
step_ = OpParam::GetVarValue<GType>("Step", inputs, *scope);
output_ = OpParam::OutFrom<GType>(outputs, *scope);
}
GType *Start() const { return start_; }
const GType *End() const { return end_; }
const GType *Step() const { return step_; }
GType *Output() const { return output_; }
private:
GType *start_;
GType *end_;
GType *step_;
GType *output_;
};
DECLARE_KERNEL(Range, RangeParam);
} // namespace operators
} // namespace paddle_mobile
#endif // RANGE_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef REDUCE_PROD_OP
#include "operators/kernel/reduce_prod_kernel.h"
#include <operators/reduce_prod_op.h>
#include <array>
#include "framework/data_type.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ReduceProdKernel<CPU, float>::Init(ReduceProdParam<CPU>* param) {
return true;
}
template <>
void ReduceProdKernel<CPU, float>::Compute(const ReduceProdParam<CPU>& param) {
auto* input = param.Input();
if (input->type() == type_id<int>().hash_code()) {
bool reduce_all = param.isReduceAll();
auto* output = param.Output();
auto dim = param.getDim();
auto* out_data = output->mutable_data<int>();
const auto* input_x_data = input->data<int>();
auto dims = param.getDim();
bool keep_dim = param.isKeepDim();
if (reduce_all) {
size_t stride = 1;
for (int j = dim[0]; j < input->dims().size(); ++j) {
stride *= input->dims()[j];
}
auto numel = output->numel();
for (int i = 0; i < numel; i++) {
int64_t mul = 1;
for (int j = 0; j < stride; ++j, ++input_x_data) {
mul *= (*input_x_data);
}
out_data[i] = mul;
}
} else {
// todo
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif // REDUCE_PROD_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef REDUCE_PROD_OP
#pragma once
#include <vector>
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype>
class ReduceProdParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
ReduceProdParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
input_ = OpParam::InputXFrom<GType>(inputs, *scope);
output_ = OpParam::OutFrom<GType>(outputs, *scope);
reduce_all_ = GetAttr<bool>("reduce_all", attrs);
keep_dim_ = GetAttr<bool>("keep_dim", attrs);
dim_ = GetAttr<std::vector<int>>("dim", attrs);
}
const GType *Input() const { return input_; }
GType *Output() const { return output_; }
bool isReduceAll() const { return reduce_all_; }
bool isKeepDim() const { return keep_dim_; }
const vector<int> getDim() const { return dim_; }
private:
GType *input_;
GType *output_;
bool reduce_all_;
bool keep_dim_;
std::vector<int> dim_;
};
DECLARE_KERNEL(ReduceProd, ReduceProdParam)
} // namespace operators
} // namespace paddle_mobile
#endif // REDUCE_PROD_OP
...@@ -30,12 +30,14 @@ class WhileParam : public OpParam { ...@@ -30,12 +30,14 @@ class WhileParam : public OpParam {
cond_ = cond_ =
OpParam::GetVarValue<framework::LoDTensor>("Condition", inputs, *scope); OpParam::GetVarValue<framework::LoDTensor>("Condition", inputs, *scope);
sub_block_ = OpParam::GetAttr<framework::BlockDesc *>("sub_block", attrs); sub_block_ = OpParam::GetAttr<framework::BlockDesc *>("sub_block", attrs);
is_test = OpParam::GetAttr<bool>("is_test", attrs);
} }
public: public:
const Scope *scope_; Scope *scope_;
framework::LoDTensor *cond_; framework::LoDTensor *cond_;
framework::BlockDesc *sub_block_; framework::BlockDesc *sub_block_;
bool is_test;
}; };
DECLARE_KERNEL(While, WhileParam); DECLARE_KERNEL(While, WhileParam);
......
...@@ -23,6 +23,9 @@ template <typename Dtype, typename T> ...@@ -23,6 +23,9 @@ template <typename Dtype, typename T>
void LodResetOp<Dtype, T>::InferShape() const { void LodResetOp<Dtype, T>::InferShape() const {
const auto &input_dims = this->param_.input_x_->dims(); const auto &input_dims = this->param_.input_x_->dims();
this->param_.output_->Resize(input_dims); this->param_.output_->Resize(input_dims);
if (this->param_.append) {
this->param_.output_->set_lod(this->param_.input_x_->lod());
}
} }
} // namespace operators } // namespace operators
......
...@@ -131,6 +131,11 @@ inline float Active(const float &x) { ...@@ -131,6 +131,11 @@ inline float Active(const float &x) {
return x; return x;
} }
template <ActivationType Act = IDENTITY>
inline int Active(const int &x) {
return x;
}
template <> template <>
inline float Active<RELU>(const float &x) { inline float Active<RELU>(const float &x) {
return std::max(x, 0.f); return std::max(x, 0.f);
......
...@@ -1330,6 +1330,55 @@ class FillConstantParam : public OpParam { ...@@ -1330,6 +1330,55 @@ class FillConstantParam : public OpParam {
}; };
#endif #endif
#ifdef FILL_CONSTANT_BATCH_SIZE_LIKE_OP
template <typename Dtype>
class FillConstantBatchSizeLikeParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FillConstantBatchSizeLikeParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
input_ = InputFrom<GType>(inputs, *scope);
out_var_ = OutVarFrom(outputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
dtype_ = GetAttr<int>("dtype", attrs);
shape_ = GetAttr<vector<int>>("shape", attrs);
value_ = GetAttr<float>("value", attrs);
input_dim_idx_ = GetAttr<int>("input_dim_idx", attrs);
output_dim_idx_ = GetAttr<int>("output_dim_idx", attrs);
}
Variable *OutVar() const { return out_var_; }
const GType *Input() const { return input_; }
GType *Out() const { return out_; }
const int &DataDtype() const { return dtype_; }
const vector<int> &Shape() const { return shape_; }
const float &Value() const { return value_; }
int InputDimIdx() const { return input_dim_idx_; }
int OutputDimIdx() const { return output_dim_idx_; }
private:
GType *input_;
Variable *out_var_;
GType *out_;
int dtype_;
vector<int> shape_;
float value_;
int input_dim_idx_;
int output_dim_idx_;
};
#endif
#ifdef TRANSPOSE_OP #ifdef TRANSPOSE_OP
template <typename Dtype> template <typename Dtype>
class TransposeParam : public OpParam { class TransposeParam : public OpParam {
...@@ -3236,6 +3285,7 @@ class LodResetParam : public OpParam { ...@@ -3236,6 +3285,7 @@ class LodResetParam : public OpParam {
} else { } else {
target_lod_ = OpParam::GetAttr<vector<int>>("target_lod", attrs); target_lod_ = OpParam::GetAttr<vector<int>>("target_lod", attrs);
} }
append = OpParam::GetAttr<bool>("append", attrs);
} }
public: public:
...@@ -3243,6 +3293,7 @@ class LodResetParam : public OpParam { ...@@ -3243,6 +3293,7 @@ class LodResetParam : public OpParam {
GType *input_y_; GType *input_y_;
GType *output_; GType *output_;
std::vector<int> target_lod_; std::vector<int> target_lod_;
bool append;
}; };
#endif // LOD_RESET_OP #endif // LOD_RESET_OP
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RANGE_OP
#include "operators/range_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void RangeOp<Dtype, T>::InferShape() const {
auto s_dims = this->param_.Start()->dims();
PADDLE_MOBILE_ENFORCE((s_dims.size() == 1) && (s_dims[0] == 1),
"The shape of Input(Start) should be [1].");
auto e_dims = this->param_.End()->dims();
PADDLE_MOBILE_ENFORCE((e_dims.size() == 1) && (e_dims[0] == 1),
"The shape of Input(End) should be [1].");
auto step_dims = this->param_.Step()->dims();
PADDLE_MOBILE_ENFORCE((step_dims.size() == 1) && (step_dims[0] == 1),
"The shape of Input(Step) should be [1].");
this->param_.Output()->Resize({-1});
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(range, ops::RangeOp);
#endif
#endif // ASSIGN_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RANGE_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/range_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
DECLARE_OPERATOR(Range, RangeParam, RangeKernel);
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef REDUCE_PROD_OP
#include "operators/reduce_prod_op.h"
#include <algorithm>
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void ReduceProdOp<Dtype, T>::InferShape() const {
PADDLE_MOBILE_ENFORCE(this->param_.Input() != nullptr,
"Input (X) of ReduceOp op should not be null.");
PADDLE_MOBILE_ENFORCE(this->param_.Output() != nullptr,
"Output (Output) of ReduceOp op should not be null.");
auto x_dims = this->param_.Input()->dims();
auto x_rank = x_dims.size();
PADDLE_MOBILE_ENFORCE(x_rank <= 6,
"Tensors with rank at most 6 are supported.");
auto dims = this->param_.getDim();
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) dims[i] = x_rank + dims[i];
PADDLE_MOBILE_ENFORCE(
dims[i] < x_rank,
"The dim should be in the range [-rank(input), rank(input)).");
}
sort(dims.begin(), dims.end());
bool reduce_all = this->param_.isReduceAll();
bool keep_dim = this->param_.isKeepDim();
if (reduce_all) {
if (keep_dim)
this->param_.Output()->Resize(
framework::make_ddim(std::vector<int64_t>(x_rank, 1)));
else
this->param_.Output()->Resize({1});
} else {
auto dims_vector = vectorize(x_dims);
if (keep_dim) {
for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = 1;
}
} else {
const int kDelFlag = -2;
for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = kDelFlag;
}
dims_vector.erase(
remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
}
auto out_dims = framework::make_ddim(dims_vector);
this->param_.Output()->Resize(out_dims);
if (dims[0] != 0) {
// Only pass LoD when not reducing on the first dim.
this->param_.Output()->set_lod(this->param_.Input()->lod());
}
}
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(reduce_prod, ops::ReduceProdOp);
#endif
#endif // REDUCE_PROD_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef REDUCE_PROD_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/reduce_prod_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
DECLARE_OPERATOR(ReduceProd, ReduceProdParam, ReduceProdKernel);
} // namespace operators
} // namespace paddle_mobile
#endif
...@@ -56,10 +56,37 @@ void SliceOp<Dtype, T>::InferShape() const { ...@@ -56,10 +56,37 @@ void SliceOp<Dtype, T>::InferShape() const {
this->param_.output_->dims().size())) == this->param_.output_->dims().size())) ==
3, 3,
"op only support slice channel now"); "op only support slice channel now");
#else
PADDLE_MOBILE_ENFORCE(input->dims().size() - axes[0] == 3,
"op only support slice channel now");
#endif #endif
if (input->dims().size() >= 4) {
PADDLE_MOBILE_ENFORCE(input->dims().size() - axes[0] == 3,
"op only support slice channel now");
}
auto starts = this->param_.starts_;
auto ends = this->param_.ends_;
framework::DDim out_dims(input->dims());
PADDLE_MOBILE_ENFORCE(starts.size() == ends.size(),
"starts.size should equal ends.size");
PADDLE_MOBILE_ENFORCE(axes.size() == starts.size(),
"axes.size should equal starts.size");
int dim_value, start, end;
for (size_t i = 0; i < axes.size(); ++i) {
dim_value = out_dims[axes[i]];
if (dim_value > 0) {
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, 0);
end = std::max(end, 0);
// start = std::min(start, dim_value);
end = std::min(end, dim_value);
// start = std::min(start, end);
PADDLE_MOBILE_ENFORCE(end > start, "end should greater than start");
out_dims[axes[i]] = end - start;
}
}
output->Resize(out_dims);
if (axes[0] != 0) {
output->set_lod(input->lod());
}
} }
} // namespace operators } // namespace operators
......
...@@ -370,6 +370,12 @@ if(NOT FOUND_MATCH) ...@@ -370,6 +370,12 @@ if(NOT FOUND_MATCH)
set(ASSIGN_VALUE_OP ON) set(ASSIGN_VALUE_OP ON)
set(NEAREST_INTERP_OP ON) set(NEAREST_INTERP_OP ON)
set(LEAKY_RELU_OP ON) set(LEAKY_RELU_OP ON)
set(ASSIGN_OP ON)
set(CONDITIONAL_BLOCK_OP ON)
set(EQUAL_OP ON)
set(FILL_CONSTANT_BATCH_SIZE_LIKE_OP ON)
set(RANGE_OP ON)
set(REDUCE_PROD_OP ON)
endif() endif()
# option(BATCHNORM_OP "" ON) # option(BATCHNORM_OP "" ON)
...@@ -719,3 +725,21 @@ endif() ...@@ -719,3 +725,21 @@ endif()
if (EXP_OP) if (EXP_OP)
add_definitions(-DEXP_OP) add_definitions(-DEXP_OP)
endif () endif ()
if (ASSIGN_OP)
add_definitions(-DASSIGN_OP)
endif()
if (CONDITIONAL_BLOCK_OP)
add_definitions(-DCONDITIONAL_BLOCK_OP)
endif()
if (EQUAL_OP)
add_definitions(-DEQUAL_OP)
endif()
if (FILL_CONSTANT_BATCH_SIZE_LIKE_OP)
add_definitions(-DFILL_CONSTANT_BATCH_SIZE_LIKE_OP)
endif()
if (RANGE_OP)
add_definitions(-DRANGE_OP)
endif()
if (REDUCE_PROD_OP)
add_definitions(-DREDUCE_PROD_OP)
endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册