提交 88f784a1 编写于 作者: L lijiancheng0614

Merge branch 'develop' of https://github.com/PaddlePaddle/paddle-mobile into develop

......@@ -62,6 +62,8 @@ const char *G_OP_TYPE_CRF = "crf_decoding";
const char *G_OP_TYPE_BILINEAR_INTERP = "bilinear_interp";
const char *G_OP_TYPE_FLATTEN = "flatten";
const char *G_OP_TYPE_SHAPE = "shape";
const char *G_OP_TYPE_ELEMENTWISE_MUL = "elementwise_mul";
const char *G_OP_TYPE_SUM = "sum";
const char *G_OP_TYPE_QUANTIZE = "quantize";
const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
......@@ -115,7 +117,8 @@ std::unordered_map<
{G_OP_TYPE_FLATTEN, {{"X"}, {"Out"}}},
{G_OP_TYPE_SHAPE, {{"Input"}, {"Out"}}},
{G_OP_TYPE_CONV_TRANSPOSE, {{"Input"}, {"Output"}}},
{G_OP_TYPE_SUM, {{"X"}, {"Out"}}},
{G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
{G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}}};
} // namespace paddle_mobile
......@@ -126,6 +126,8 @@ extern const char *G_OP_TYPE_REGION;
extern const char *G_OP_TYPE_FUSION_CONV_BN;
extern const char *G_OP_TYPE_CONV_TRANSPOSE;
extern const char *G_OP_TYPE_PRELU;
extern const char *G_OP_TYPE_SUM;
extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
extern const char *G_OP_TYPE_QUANTIZE;
extern const char *G_OP_TYPE_DEQUANTIZE;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <initializer_list>
#include <vector>
#include "framework/tensor.h"
#include "framework/tensor_util.h"
namespace paddle_mobile {
namespace framework {
// Vector<T> implements the std::vector interface, and can get Data or
// MutableData from any place. The data will be synced implicitly inside.
template <typename T>
class Vector {
public:
using value_type = T;
// Default ctor. Create empty Vector
Vector() { InitEmpty(); }
// Fill vector with value. The vector size is `count`.
explicit Vector(size_t count, const T& value = T()) {
InitEmpty();
if (count != 0) {
resize(count);
T* ptr = begin();
for (size_t i = 0; i < count; ++i) {
ptr[i] = value;
}
}
}
// Ctor with init_list
Vector(std::initializer_list<T> init) {
if (init.size() == 0) {
InitEmpty();
} else {
InitByIter(init.size(), init.begin(), init.end());
}
}
// implicit cast from std::vector.
template <typename U>
Vector(const std::vector<U>& dat) { // NOLINT
if (dat.size() == 0) {
InitEmpty();
} else {
InitByIter(dat.size(), dat.begin(), dat.end());
}
}
// Copy ctor
Vector(const Vector<T>& other) { this->operator=(other); }
// Copy operator
Vector<T>& operator=(const Vector<T>& other) {
if (other.size() != 0) {
this->InitByIter(other.size(), other.begin(), other.end());
} else {
InitEmpty();
}
return *this;
}
// Move ctor
Vector(Vector<T>&& other) {
this->size_ = other.size_;
this->flag_ = other.flag_;
if (other.cuda_vec_.memory_size()) {
this->cuda_vec_.ShareDataWith(other.cuda_vec_);
}
if (other.cpu_vec_.memory_size()) {
this->cpu_vec_.ShareDataWith(other.cpu_vec_);
}
}
// CPU data access method. Mutable.
T& operator[](size_t i) {
MutableCPU();
return const_cast<T*>(cpu_vec_.data<T>())[i];
}
// CPU data access method. Immutable.
const T& operator[](size_t i) const {
// ImmutableCPU();
return cpu_vec_.data<T>()[i];
}
// std::vector iterator methods. Based on CPU data access method
size_t size() const { return size_; }
T* begin() { return capacity() == 0 ? &EmptyDummy() : &this->operator[](0); }
T* end() {
return capacity() == 0 ? &EmptyDummy() : &this->operator[](size());
}
T& front() { return *begin(); }
T& back() {
auto it = end();
--it;
return *it;
}
const T* begin() const {
return capacity() == 0 ? &EmptyDummy() : &this->operator[](0);
}
const T* end() const {
return capacity() == 0 ? &EmptyDummy() : &this->operator[](size());
}
const T* cbegin() const { return begin(); }
const T* cend() const { return end(); }
const T& back() const {
auto it = end();
--it;
return *it;
}
T* data() { return begin(); }
const T* data() const { return begin(); }
const T& front() const { return *begin(); }
// end of std::vector iterator methods
// assign this from iterator.
// NOTE: the iterator must support `end-begin`
template <typename Iter>
void assign(Iter begin, Iter end) {
InitByIter(end - begin, begin, end);
}
// push_back. If the previous capacity is not enough, the memory will
// double.
void push_back(T elem) {
if (size_ + 1 > capacity()) {
reserve((size_ + 1) << 1);
}
*end() = elem;
++size_;
}
// extend a vector by iterator.
// NOTE: the iterator must support end-begin
template <typename It>
void Extend(It begin, It end) {
size_t pre_size = size_;
resize(pre_size + (end - begin));
T* ptr = this->begin() + pre_size;
for (; begin < end; ++begin, ++ptr) {
*ptr = *begin;
}
}
// resize the vector
void resize(size_t size) {
if (size + 1 <= capacity()) {
size_ = size;
} else {
MutableCPU();
Tensor cpu_tensor;
T* ptr = cpu_tensor.mutable_data<T>(
framework::make_ddim({static_cast<int64_t>(size)}));
const T* old_ptr =
cpu_vec_.memory_size() == 0 ? nullptr : cpu_vec_.data<T>();
if (old_ptr != nullptr) {
std::copy(old_ptr, old_ptr + size_, ptr);
}
size_ = size;
cpu_vec_.ShareDataWith(cpu_tensor);
}
}
// clear
void clear() {
size_ = 0;
flag_ = kDirty | kDataInCPU;
}
size_t capacity() const {
return cpu_vec_.memory_size() / SizeOfType(typeid(T));
}
// reserve data
void reserve(size_t size) {
size_t pre_size = size_;
resize(size);
resize(pre_size);
}
// implicit cast operator. Vector can be cast to std::vector implicitly.
operator std::vector<T>() const {
std::vector<T> result;
result.resize(size());
std::copy(begin(), end(), result.begin());
return result;
}
bool operator==(const Vector<T>& other) const {
if (size() != other.size()) return false;
auto it1 = cbegin();
auto it2 = other.cbegin();
for (; it1 < cend(); ++it1, ++it2) {
if (*it1 != *it2) {
return false;
}
}
return true;
}
private:
void InitEmpty() {
size_ = 0;
flag_ = kDataInCPU;
}
template <typename Iter>
void InitByIter(size_t size, Iter begin, Iter end) {
T* ptr = this->cpu_vec_.template mutable_data<T>(
framework::make_ddim({static_cast<int64_t>(size)}));
for (size_t i = 0; i < size; ++i) {
*ptr++ = *begin++;
}
flag_ = kDataInCPU | kDirty;
size_ = size;
}
enum DataFlag {
kDataInCPU = 0x01,
kDataInCUDA = 0x02,
// kDirty means the data has been changed in one device.
kDirty = 0x10
};
void MutableCPU() { flag_ = kDirty | kDataInCPU; }
void UnsetFlag(int flag) const { flag_ &= ~flag; }
void SetFlag(int flag) const { flag_ |= flag; }
static T& EmptyDummy() {
static T dummy = T();
return dummy;
}
mutable int flag_;
mutable Tensor cpu_vec_;
mutable Tensor cuda_vec_;
size_t size_;
};
} // namespace framework
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "framework/selected_rows.h"
namespace paddle_mobile {
namespace framework {
struct ReAllocateVisitor {
ReAllocateVisitor(framework::Tensor* tensor, const framework::DDim& dims)
: tensor_(tensor), dims_(dims) {}
template <typename T>
void operator()() const {
framework::Tensor cpu_tensor;
T* ptr = cpu_tensor.mutable_data<T>(dims_);
const T* old_ptr =
tensor_->memory_size() == 0 ? nullptr : tensor_->data<T>();
if (old_ptr != nullptr) {
std::copy(old_ptr, old_ptr + tensor_->numel(), ptr);
}
tensor_->ShareDataWith(cpu_tensor);
}
framework::Tensor* tensor_;
framework::DDim dims_;
};
// TensorCopyVisitor(value, i * value_width, *value_.get(),
// index * value_width, value_width));
struct TensorCopyVisitor {
TensorCopyVisitor(framework::Tensor* dst, int64_t dst_offset,
const framework::Tensor src, int64_t src_offset,
int64_t size)
: dst_(dst),
dst_offset_(dst_offset),
src_(src),
src_offset_(src_offset),
size_(size) {}
template <typename T>
void operator()() const {
// TODO(Yancey1989): support other place
memory::Copy(dst_->mutable_data<T>() + dst_offset_,
src_.data<T>() + src_offset_, size_ * sizeof(T));
}
framework::Tensor* dst_;
int64_t dst_offset_;
framework::Tensor src_;
int64_t src_offset_;
int64_t size_;
};
bool SelectedRows::HasKey(int64_t key) const {
return std::find(rows_.begin(), rows_.end(), key) == rows_.end() ? false
: true;
}
// std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
// framework::Tensor* value) const {
// PADDLE_MOBILE_ENFORCE(value->IsInitialized(),
// "The value tensor should be initialized.");
// std::vector<int64_t> non_keys;
// int64_t value_width = value_->numel() / value_->dims()[0];
// PADDLE_MOBILE_ENFORCE(value_width == value->numel() / value->dims()[0],
// "output tensor should have the same shape with table "
// "execpt the dims[0].");
//
// for (size_t i = 0; i < keys.size(); ++i) {
// int64_t index = Index(keys[i]);
// if (index == -1) {
// non_keys.push_back(keys[i]);
// } else {
// framework::VisitDataType(
// framework::ToDataType(value_->type()),
// TensorCopyVisitor(value, i * value_width, *value_.get(),
// index * value_width, value_width));
// }
// }
// return non_keys;
//}
// bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
// PADDLE_MOBILE_ENFORCE(value.IsInitialized(), "The value should be
// initialized."); if (value_->IsInitialized()) {
// PADDLE_MOBILE_ENFORCE(
// value.type() == value_->type(),
// "The type of the value should be same with the original value");
// }
// PADDLE_MOBILE_ENFORCE(value.dims()[0] == static_cast<size_t>(1),
// "The first dim of value should be 1.");
// auto index = Index(key);
// bool is_new_key = false;
// if (index == -1) {
// rows_.push_back(key);
// index = rows_.size() - 1;
// is_new_key = true;
// // whether need to resize the table
// if (static_cast<int64_t>(rows_.size()) > value_->dims()[0]) {
// auto dims = value_->dims();
// dims[0] = (dims[0] + 1) << 1;
// framework::VisitDataType(framework::ToDataType(value.type()),
// ReAllocateVisitor(value_.get(), dims));
// }
// }
//
// framework::VisitDataType(
// framework::ToDataType(value.type()),
// TensorCopyVisitor(value_.get(),
// index * value_->numel() / value_->dims()[0], value,
// static_cast<int64_t>(0), value.numel()));
// return is_new_key;
//}
} // namespace framework
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "framework/lod_tensor.h"
#include "framework/tensor.h"
#include "memory/t_malloc.h"
#include "mixed_vector.h"
namespace paddle_mobile {
namespace framework {
class SelectedRows {
/*
* @brief We can use the SelectedRows structure to reproduce a sparse table.
* A sparse table is a key-value structure that the key is an `int64_t`
* number,
* and the value is a Tensor which the first dimension is 0.
* You can use the following interface to operate the sparse table, and you
* can find
* some detail information from the comments of each interface:
*
* HasKey(key), whether the sparse table has the specified key.
* Set(key, value), set a key-value pair into the sparse table.
* Get(keys, value*), get value by given key list and apply it to the given
* value pointer
* with the specified offset.
*
*/
public:
SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
: rows_(rows), height_(height) {
value_.reset(new Tensor());
}
SelectedRows() {
height_ = 0;
value_.reset(new Tensor());
}
// platform::Place place() const { return value_->place(); }
const Tensor& value() const { return *value_; }
Tensor* mutable_value() { return value_.get(); }
int64_t height() const { return height_; }
void set_height(int64_t height) { height_ = height; }
const Vector<int64_t>& rows() const { return rows_; }
Vector<int64_t>* mutable_rows() { return &rows_; }
void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
/*
* @brief wheter has the specified key in the table.
*
* @return true if the key is exists.
*/
bool HasKey(int64_t key) const;
/*
* @brief Get value by the key list, if the
*
* @return a list of keys which does not exists in table
*/
std::vector<int64_t> Get(std::vector<int64_t> keys,
framework::Tensor* tensor) const;
/*
* @brief Set a key-value pair into the table.
* This function will double the value memory if it's not engouth.
*
* @note:
* 1. The first dim of the value should be 1
* 2. The value should be initialized and the data type
* should be the same with the table.
*
* @return true if the key is a new one, otherwise false
*
*/
bool Set(int64_t key, const Tensor& value);
/*
* @brief Get the index of key in rows
*
* @return -1 if the key does not exists.
*/
int64_t Index(int64_t key) const {
auto it = std::find(rows_.begin(), rows_.end(), key);
if (it == rows_.end()) {
return static_cast<int64_t>(-1);
}
return static_cast<int64_t>(std::distance(rows_.begin(), it));
}
DDim GetCompleteDims() const {
std::vector<int64_t> dims = vectorize(value_->dims());
dims[0] = height_;
return make_ddim(dims);
}
private:
// Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
// SelectedRows are simply concated when adding together. Until a
// SelectedRows add a Tensor, will the duplicate rows be handled.
Vector<int64_t> rows_;
std::unique_ptr<Tensor> value_{nullptr};
int64_t height_;
};
/*
* Serialize/Desiralize SelectedRows to std::ostream
* You can pass ofstream or ostringstream to serilize to file
* or to a in memory string. GPU tensor will be copied to CPU.
*/
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows);
void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows);
} // namespace framework
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISEMUL_OP
#include "elementwise_mul_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void ElementwiseMulOp<Dtype, T>::InferShape() const {
auto x_dim = this->param_.InputX()->dims();
this->param_.Out()->Resize(x_dim);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(elementwise_mul, ops::ElementwiseMulOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
REGISTER_OPERATOR_MALI_GPU(elementwise_mul, ops::ElementwiseMulOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISEMUL_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "kernel/elementwise_mul_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using std::string;
template <typename DeviceType, typename T>
class ElementwiseMulOp : public framework::OperatorWithKernel<
DeviceType, ElementwiseMulParam<DeviceType>,
operators::ElementwiseMulKernel<DeviceType, T>> {
public:
ElementwiseMulOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, ElementwiseMulParam<DeviceType>,
operators::ElementwiseMulKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, ElementwiseMulParam<DeviceType>,
operators::ElementwiseMulKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISEMUL_OP
#include "operators/kernel/elementwise_mul_kernel.h"
#include "operators/kernel/central-arm-func/elementwise_mul_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ElementwiseMulKernel<CPU, float>::Init(ElementwiseMulParam<CPU> *param) {
return true;
}
template <>
void ElementwiseMulKernel<CPU, float>::Compute(
const ElementwiseMulParam<CPU> &param) const {
ElementwiseMulCompute<float>(param);
param.Out()->set_lod(param.InputX()->lod());
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SUM_OP
#include "operators/kernel/sum_kernel.h"
#include "operators/kernel/central-arm-func/sum_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool SumKernel<CPU, float>::Init(SumParam<CPU> *param) {
return true;
}
template <>
void SumKernel<CPU, float>::Compute(const SumParam<CPU> &param) const {
SumCompute<float>(param);
param.Out()->set_lod(param.Inputs()[0]->lod());
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISEMUL_OP
#pragma once
#include "operators/math/elementwise_op_function.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename T>
struct MulFunctor {
inline T operator()(T a, T b) const { return a * b; }
};
template <typename P>
void ElementwiseMulCompute(const ElementwiseMulParam<CPU> &param) {
const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY();
Tensor *Out = param.Out();
Out->mutable_data<float>();
int axis = param.Axis();
ElementwiseComputeEx<MulFunctor<float>, float>(input_x, input_y, axis,
MulFunctor<float>(), Out);
}
template class ElementwiseMulKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SUM_OP
#pragma once
#include "operators/math/selected_rows_functor.h"
namespace paddle_mobile {
namespace operators {
using LoDTensorArray = std::vector<LoDTensor>;
template <typename P>
void SumCompute(const SumParam<CPU> &param) {
auto inputsvars = param.InputsVars();
int N = inputsvars.size();
auto *outvar = param.OutVar();
bool in_place = outvar == inputsvars[0];
DLOG << "11:";
if (outvar->IsType<framework::LoDTensor>()) {
auto *out = outvar->GetMutable<LoDTensor>();
if (!in_place) {
out->mutable_data<float>();
}
DLOG << "1:";
auto *outptr = out->data<float>();
// auto result = Flatten(*out);
if (!in_place) {
std::fill(out->data<float>(), out->data<float>() + out->numel(), 0);
}
math::SelectedRowsAddToTensor<float> functor;
for (int i = in_place ? 1 : 0; i < N; i++) {
if (inputsvars[i]->IsType<framework::LoDTensor>()) {
auto *in_t = inputsvars[i]->Get<framework::LoDTensor>();
auto *inptr = in_t->data<float>();
if (in_t->numel() == 0) {
continue;
}
for (int j = 0; j < out->numel(); ++j) {
outptr[j] = outptr[j] + inptr[j];
}
} else if (inputsvars[i]->IsType<framework::SelectedRows>()) {
auto *in_t = inputsvars[i]->Get<framework::SelectedRows>();
functor(*in_t, out);
} else {
PADDLE_MOBILE_THROW_EXCEPTION(
"Variable type must be LoDTensor/SelectedRows.");
}
}
} else if (outvar->IsType<framework::SelectedRows>()) {
DLOG << "2:";
std::unique_ptr<framework::SelectedRows> in0;
if (in_place) {
// If is in_place, we store the input[0] to in0
auto *in_sel0 = inputsvars[0]->Get<SelectedRows>();
auto &rows = in_sel0->rows();
//#ifdef PADDLE_WITH_CUDA
// std::vector<int64_t> rows_in_cpu;
// rows_in_cpu.reserve(rows.size());
// for (auto item : rows) {
// rows_in_cpu.push_back(item);
// }
// in0.reset(new framework::SelectedRows(rows_in_cpu,
// in_sel0.height()));
//#else
in0.reset(new framework::SelectedRows(rows, in_sel0->height()));
//#endif
in0->mutable_value()->ShareDataWith(in_sel0->value());
}
auto get_selected_row = [&](size_t i) -> const SelectedRows & {
if (i == 0 && in0) {
return *in0.get();
} else {
return *(inputsvars[i]->Get<SelectedRows>());
}
};
auto *out = outvar->GetMutable<SelectedRows>();
out->mutable_rows()->clear();
auto *out_value = out->mutable_value();
// Runtime InferShape
size_t first_dim = 0;
for (int i = 0; i < N; i++) {
auto &sel_row = get_selected_row(i);
first_dim += sel_row.rows().size();
}
auto in_dim = framework::vectorize(get_selected_row(N - 1).value().dims());
in_dim[0] = static_cast<int64_t>(first_dim);
out_value->Resize(framework::make_ddim(in_dim));
// if all the input sparse vars are empty, no need to
// merge these vars.
if (first_dim == 0UL) {
return;
}
out_value->mutable_data<float>();
math::SelectedRowsAddTo<float> functor;
int64_t offset = 0;
for (int i = 0; i < N; i++) {
auto &sel_row = get_selected_row(i);
if (sel_row.rows().size() == 0) {
continue;
}
PADDLE_MOBILE_ENFORCE(out->height() == sel_row.height());
functor(sel_row, offset, out);
offset += sel_row.value().numel();
}
} else if (outvar->IsType<LoDTensorArray>()) {
DLOG << "3:";
auto &out_array = *outvar->GetMutable<LoDTensorArray>();
for (size_t i = in_place ? 1 : 0; i < inputsvars.size(); ++i) {
PADDLE_MOBILE_ENFORCE(inputsvars[i]->IsType<LoDTensorArray>(),
"Only support all inputs are TensorArray");
auto *in_array = inputsvars[i]->Get<LoDTensorArray>();
for (size_t i = 0; i < in_array->size(); ++i) {
if ((*in_array)[i].numel() != 0) {
if (i >= out_array.size()) {
out_array.resize(i + 1);
}
if (out_array[i].numel() == 0) {
framework::TensorCopy((*in_array)[i], &out_array[i]);
out_array[i].set_lod((*in_array)[i].lod());
} else {
PADDLE_MOBILE_ENFORCE(out_array[i].lod() == (*in_array)[i].lod());
auto *inptr = (*in_array)[i].data<float>();
auto *outptr = out_array[i].data<float>();
for (int j = 0; j < (*in_array)[i].numel(); ++j) {
outptr[j] = inptr[j] + outptr[j];
}
}
}
}
}
} else {
DLOG << "2:";
if (outvar->IsType<framework::Tensor>()) {
DLOG << "3: ";
}
PADDLE_MOBILE_THROW_EXCEPTION(
"Unexpected branch, output variable type is %s", outvar->Type().name());
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISEMUL_OP
#pragma once
#include "framework/operator.h"
#include "operators/math/elementwise_op_function.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using namespace framework;
template <typename DeviceType, typename T>
class ElementwiseMulKernel
: public framework::OpKernelBase<DeviceType,
ElementwiseMulParam<DeviceType>> {
public:
void Compute(const ElementwiseMulParam<DeviceType> &param) const;
bool Init(ElementwiseMulParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SUM_OP
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using namespace framework;
template <typename DeviceType, typename T>
class SumKernel
: public framework::OpKernelBase<DeviceType, SumParam<DeviceType>> {
public:
void Compute(const SumParam<DeviceType> &param) const;
bool Init(SumParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -1667,7 +1667,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
const int w_times = (out_w - 2) / 3;
float32x4_t zero = vdupq_n_f32(0.0);
for (int b = batch_size; b > 0; --b) {
#pragma omp parallel for
#pragma omp parallel for
for (int j = 0; j < c; j++) {
const float *input_row_ptr;
float *output_row_ptr;
......@@ -1912,9 +1912,7 @@ void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter,
float w20 = filter_data[6];
float w21 = filter_data[7];
float w22 = filter_data[8];
float32x4_t biasv = vld1q_dup_f32(bias_data);
for (int i = 0; i < output_height; i += 1) {
for (int m = 0; m < output_width - 2; m += 3) {
float *output_ptr = output_data + i * output_width + m;
......@@ -1949,8 +1947,9 @@ void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter,
out0 = vmlaq_n_f32(out0, in4, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vaddq_f32(out0, biasv);
if (if_bias) {
out0 = vaddq_f32(out0, biasv);
}
vst1q_lane_f32(output_ptr, out0, 0);
vst1q_lane_f32(output_ptr + 1, out0, 1);
vst1q_lane_f32(output_ptr + 2, out0, 2);
......@@ -1960,16 +1959,18 @@ void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter,
}
for (int j = m; j < output_width; j++) {
output_data[i * output_width + j] =
input_data[(2 * i - 1) * input_width + 2 * j - 1] * w00 +
input_data[(2 * i - 1) * input_width + 2 * j] * w01 +
input_data[(2 * i - 1) * input_width + 2 * j + 1] * w02 +
input_data[(2 * i) * input_width + 2 * j - 1] * w10 +
input_data[(2 * i) * input_width + 2 * j] * w11 +
input_data[(2 * i) * input_width + 2 * j + 1] * w12 +
input_data[(2 * i + 1) * input_width + 2 * j - 1] * w20 +
input_data[(2 * i + 1) * input_width + 2 * j] * w21 +
input_data[(2 * i + 1) * input_width + 2 * j + 1] * w22;
output_data[i * output_width + j] += *bias_data;
input_data[(2 * i) * input_width + 2 * j] * w00 +
input_data[(2 * i) * input_width + 2 * j + 1] * w01 +
input_data[(2 * i) * input_width + 2 * j + 2] * w02 +
input_data[(2 * i + 1) * input_width + 2 * j] * w10 +
input_data[(2 * i + 1) * input_width + 2 * j + 1] * w11 +
input_data[(2 * i + 1) * input_width + 2 * j + 2] * w12 +
input_data[(2 * i + 2) * input_width + 2 * j] * w20 +
input_data[(2 * i + 2) * input_width + 2 * j + 1] * w21 +
input_data[(2 * i + 2) * input_width + 2 * j + 2] * w22;
if (if_bias) {
output_data[i * output_width + j] += *bias_data;
}
}
}
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "framework/selected_rows.h"
#define INLINE_FOR2(sizei, sizej) \
for (int64_t i = 0; i < sizei; i++) \
for (int64_t j = 0; j < sizej; j++)
namespace paddle_mobile {
namespace operators {
namespace math {
// SelectedRows + SelectedRows will simplely concat value and rows.
// The real computation happens in dealing with LoDTensor.
// template <typename T>
// struct SelectedRowsAdd {
// void operator()(
// const framework::SelectedRows& input1,
// const framework::SelectedRows& input2,
// framework::SelectedRows* output);
//};
//
// template <typename T>
// struct SelectedRowsAddTensor {
// void operator()(
// const framework::SelectedRows& input1,
// const framework::Tensor& input2, framework::Tensor* output);
//};
// input2 = input1 + input2
template <typename T>
struct SelectedRowsAddTo {
void operator()(const framework::SelectedRows& input1,
const int64_t input2_offset,
framework::SelectedRows* input2) {
auto in1_height = input1.height();
PADDLE_MOBILE_ENFORCE(in1_height == input2->height());
auto& in1_rows = input1.rows();
auto& in2_rows = *(input2->mutable_rows());
auto& in1_value = input1.value();
auto* in2_value = input2->mutable_value();
// concat rows
in2_rows.Extend(in1_rows.begin(), in1_rows.end());
// auto in1_place = input1.place();
// PADDLE_ENFORCE(platform::is_cpu_place(in1_place));
// auto in2_place = input2->place();
// PADDLE_ENFORCE(platform::is_cpu_place(in2_place));
auto* in1_data = in1_value.data<T>();
auto* in2_data = in2_value->data<T>();
memory::Copy(in2_data + input2_offset, in1_data,
in1_value.numel() * sizeof(T));
}
};
// input2 = input1 + input2
template <typename T>
struct SelectedRowsAddToTensor {
void operator()(const framework::SelectedRows& input1,
framework::Tensor* input2) {
auto in1_height = input1.height();
auto in2_dims = input2->dims();
PADDLE_MOBILE_ENFORCE(in1_height == in2_dims[0]);
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_MOBILE_ENFORCE(in1_row_numel == input2->numel() / in1_height);
auto* in1_data = in1_value.data<T>();
auto* input2_data = input2->data<T>();
for (size_t i = 0; i < in1_rows.size(); i++) {
for (int64_t j = 0; j < in1_row_numel; j++) {
input2_data[in1_rows[i] * in1_row_numel + j] +=
in1_data[i * in1_row_numel + j];
}
}
}
};
// namespace scatter {
//// functors for manuplating SelectedRows data
// template <typename T>
// struct MergeAdd {
// // unary functor, merge by adding duplicated rows in
// // the input SelectedRows object.
// framework::SelectedRows operator()(
// const framework::SelectedRows& input);
//};
// template <typename T>
// struct Add {
// framework::SelectedRows operator()(
// const framework::SelectedRows& input1,
// const framework::SelectedRows& input2) {
// framework::SelectedRows out;
// out.set_rows(input1.rows());
// out.set_height(input1.height());
// out.mutable_value()->mutable_data<T>(input1.value().dims(),
// );
// auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
// auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
// auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
// e_out.device(*context.eigen_device()) = e_in1 + e_in2;
// return out;
// }
//};
// template <typename T>
// struct Mul {
// // multiply two SelectedRows
// framework::SelectedRows operator()(
// const framework::SelectedRows& input1,
// const framework::SelectedRows& input2) {
// framework::SelectedRows out;
// out.set_rows(input1.rows());
// out.set_height(input1.height());
// out.mutable_value()->mutable_data<T>(input1.value().dims()
// );
// auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
// auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
// auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
// e_out.device(*context.eigen_device()) = e_in1 * e_in2;
// return out;
// }
// // multiply scalar to SelectedRows
// framework::SelectedRows operator()(
// const framework::SelectedRows& input1,
// const T input2) {
// framework::SelectedRows out;
// out.set_rows(input1.rows());
// out.set_height(input1.height());
// out.mutable_value()->mutable_data<T>(input1.value().dims(),
// );
// auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
// auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
// e_out.device(*context.eigen_device()) = input2 * e_in1;
// return out;
// }
//};
enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
// out = seleted_rows_in / tensor
template <typename T>
struct UpdateToTensor {
void operator()(const ScatterOps& op, const framework::SelectedRows& input1,
framework::Tensor* input2);
};
// namespace scatter
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -35,6 +35,7 @@ using framework::AttributeMap;
using framework::LoDTensor;
using framework::Scope;
using framework::Tensor;
using framework::Variable;
using std::string;
using std::vector;
......@@ -182,6 +183,11 @@ class OpParam {
return GetMultiVarValue<T>("X", inputs, scope);
}
static vector<Variable *> InputMultiVarsFrom(const VariableNameMap &inputs,
const Scope &scope) {
return GetMultiVar("X", inputs, scope);
}
template <typename T>
static T *OutputBatchGateFrom(const VariableNameMap &outputs,
const Scope &scope) {
......@@ -216,6 +222,11 @@ class OpParam {
return GetVarValue<T>("Output", outputs, scope);
}
static Variable *OutVarFrom(const VariableNameMap &outputs,
const Scope &scope) {
return GetVar("Out", outputs, scope);
}
template <typename T>
static T *OutFrom(const VariableNameMap &outputs, const Scope &scope) {
return GetVarValue<T>("Out", outputs, scope);
......@@ -286,6 +297,19 @@ class OpParam {
}
}
static Variable *GetVar(const string &key, const VariableNameMap &var_map,
const Scope &scope) {
PADDLE_MOBILE_ENFORCE(var_map.count(key) > 0,
"%s is not contained in var_map", key.c_str())
auto var_vec = var_map.at(key);
if (!var_vec.empty()) {
auto var = scope.FindVar(var_vec[0]);
return var;
} else {
return nullptr;
}
}
static std::string getkey(const string &key, const VariableNameMap &var_map,
int index) {
auto var_vec = var_map.at(key);
......@@ -319,6 +343,19 @@ class OpParam {
}
return var_res;
}
static vector<Variable *> GetMultiVar(const string &key,
const VariableNameMap &var_map,
const Scope &scope) {
auto var_vecs = var_map.at(key);
assert(var_vecs.size() > 1);
vector<Variable *> var_res;
for (auto &var_vec : var_vecs) {
auto var = scope.FindVar(var_vec);
var_res.push_back(var);
}
return var_res;
}
};
template <typename Dtype>
......@@ -405,6 +442,47 @@ class ElementwiseAddParam : OpParam {
#endif
};
#ifdef ELEMENTWISEMUL_OP
template <typename Dtype>
class ElementwiseMulParam : OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
ElementwiseMulParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) {
input_x_ = InputXFrom<GType>(inputs, scope);
input_y_ = InputYFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope);
axis_ = GetAttr<int>("axis", attrs);
}
const GType *InputX() const { return input_x_; }
const GType *InputY() const { return input_y_; }
GType *Out() const { return out_; }
const int &Axis() const { return axis_; }
private:
GType *input_x_;
GType *input_y_;
GType *out_;
int axis_;
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::EWMulArgs fpga_EW_mul_args;
public:
const fpga::EWMulArgs &FpgaArgs() const { return fpga_EW_mul_args; }
void SetFpgaArgs(const fpga::EWMulArgs &args) { fpga_EW_mul_args = args; }
#endif
};
#endif
#ifdef FUSION_ELEMENTWISEADDRELU_OP
template <typename Dtype>
using ElementwiseAddReluParam = ElementwiseAddParam<Dtype>;
......@@ -490,6 +568,46 @@ class ConcatParam : public OpParam {
};
#endif
#ifdef SUM_OP
template <typename Dtype>
class SumParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
SumParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
inputs_vars_ = InputMultiVarsFrom(inputs, scope);
out_var_ = OutVarFrom(outputs, scope);
inputs_ = InputMultiFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope);
}
vector<Variable *> InputsVars() const { return inputs_vars_; }
Variable *OutVar() const { return out_var_; }
vector<GType *> Inputs() const { return inputs_; }
GType *Out() const { return out_; }
private:
vector<Variable *> inputs_vars_;
Variable *out_var_;
vector<GType *> inputs_;
GType *out_;
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::SumArgs fpga_sum_args;
public:
const fpga::SumArgs &FpgaArgs() const { return fpga_sum_args; }
void SetFpgaArgs(const fpga::SumArgs &args) { fpga_sum_args = args; }
#endif
};
#endif
#ifdef LRN_OP
template <typename Dtype>
class LrnParam : public OpParam {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SUM_OP
#include <vector>
#include "operators/sum_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void SumOp<Dtype, T>::InferShape() const {
auto inputs = this->param_.Inputs();
const size_t n = inputs.size();
std::vector<DDim> inputs_dims;
inputs_dims.reserve(n);
for (int i = 0; i < n; i++) {
inputs_dims.push_back(inputs[i]->dims());
}
if (n == 1) {
DLOG << "Warning: sum op have only one input, "
"may waste memory";
}
framework::DDim in_dim({0});
for (auto& x_dim : inputs_dims) {
if (framework::product(x_dim) == 0) {
continue;
}
if (framework::product(in_dim) == 0) {
in_dim = x_dim;
} else {
PADDLE_MOBILE_ENFORCE(in_dim == x_dim,
"input tensors must have same shape");
}
}
this->param_.Out()->Resize(in_dim);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(sum, ops::SumOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
REGISTER_OPERATOR_MALI_GPU(sum, ops::ConcatOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(sum, ops::ConcatOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SUM_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/sum_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using std::string;
template <typename DeviceType, typename T>
class SumOp : public framework::OperatorWithKernel<
DeviceType, SumParam<DeviceType>,
operators::SumKernel<DeviceType, T>> {
public:
SumOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, SumParam<DeviceType>,
operators::SumKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, SumParam<DeviceType>,
operators::SumKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -212,6 +212,10 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-fc-op operators/test_fusion_fc_op.cpp test_helper.h test_include.h)
target_link_libraries(test-fc-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-sum-op operators/test_sum_op.cpp test_helper.h test_include.h)
target_link_libraries(test-sum-op paddle-mobile)
# test quantize op
ADD_EXECUTABLE(test-quantize-op operators/test_quantize_op.cpp test_helper.h test_include.h)
target_link_libraries(test-quantize-op paddle-mobile)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "../test_helper.h"
#include "../test_include.h"
#include "operators/sum_op.h"
namespace paddle_mobile {
namespace framework {
template <typename Dtype>
class TestSumOp {
public:
explicit TestSumOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks();
// DLOG << " **block size " << blocks.size();
for (int i = 0; i < blocks.size(); ++i) {
std::shared_ptr<BlockDesc> block_desc = blocks[i];
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
// DLOG << " ops " << ops.size();
for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<OpDesc> op = ops[j];
if (op->Type() == "sum" && op->Input("X")[0] == "fc_2.tmp_0") {
DLOG << " sum attr size: " << op->GetAttrMap().size();
DLOG << " inputs size: " << op->GetInputs().size();
DLOG << " outputs size: " << op->GetOutputs().size();
std::shared_ptr<operators::SumOp<Dtype, float>> lrn =
std::make_shared<operators::SumOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(lrn);
}
}
}
}
std::shared_ptr<Tensor> predict_bn(const Tensor &t1, const Tensor &t2) {
// feed
auto scope = program_.scope;
Variable *x1_feed_value = scope->Var("fc_2.tmp_0");
auto tensor_x1 = x1_feed_value->GetMutable<LoDTensor>();
tensor_x1->ShareDataWith(t1);
Variable *x2_feed_value = scope->Var("fc_2.tmp_1");
auto tensor_x2 = x2_feed_value->GetMutable<LoDTensor>();
tensor_x2->ShareDataWith(t2);
Variable *output = scope->Var("fc_2.tmp_2");
auto *output_tensor = output->GetMutable<LoDTensor>();
output_tensor->mutable_data<float>({2, 96});
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims();
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor);
predict_bn(t1, t2, 0);
return out_tensor;
}
private:
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_;
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
void predict_bn(const Tensor &t1, const Tensor &t2, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
auto op = ops_of_block_[*to_predict_block.get()][j];
DLOG << "op -> run()";
op->Run();
}
}
};
template class TestSumOp<CPU>;
} // namespace framework
} // namespace paddle_mobile
int main() {
DLOG << "----------**********----------";
DLOG << "begin to run Sum Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string(g_eng) + "/model",
std::string(g_eng) + "/params");
/// input x (4,10,2,2)
paddle_mobile::framework::Tensor inputx1;
SetupTensor<float>(&inputx1, {2, 96}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx1_ptr = inputx1.data<float>();
paddle_mobile::framework::Tensor inputx2;
SetupTensor<float>(&inputx2, {2, 96}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx2_ptr = inputx2.data<float>();
paddle_mobile::framework::TestSumOp<paddle_mobile::CPU> testSumOp(program);
auto output_sum = testSumOp.predict_bn(inputx1, inputx2);
auto *output_sum_ptr = output_sum->data<float>();
DLOG << "input1 44: " << inputx1_ptr[44];
DLOG << "input2 44: " << inputx2_ptr[44];
DLOG << "out 44 :" << output_sum_ptr[44];
return 0;
}
......@@ -27,6 +27,7 @@ limitations under the License. */
static const char *g_ocr = "../models/ocr";
static const char *g_mobilenet_ssd = "../models/mobilenet+ssd";
static const char *g_genet_combine = "../models/enet";
static const char *g_eng = "../models/eng_20conv_1_9_fc";
static const char *g_mobilenet_ssd_gesture = "../models/mobilenet+ssd_gesture";
static const char *g_mobilenet_combined = "../models/mobilenet_combine";
static const char *g_googlenetv1_combined = "../models/googlenetv1_combine";
......@@ -51,6 +52,7 @@ static const char *g_test_image_1x3x224x224_banana =
static const char *g_test_image_desktop_1_3_416_416_nchw_float =
"../images/in_put_1_3_416_416_2";
static const char *g_hand = "../images/hand_image";
static const char *g_moto = "../images/moto_300x300_float";
static const char *g_imgfssd_ar = "../images/test_image_ssd_ar";
static const char *g_imgfssd_ar1 = "../images/003_0001.txt";
static const char *g_img = "../images/img.bin";
......
......@@ -33,6 +33,7 @@ if (CON GREATER -1)
set(POOL_OP ON)
set(RESHAPE_OP ON)
set(FUSION_CONVADDBNRELU_OP ON)
set(FUSION_CONVADDRELU_OP ON)
set(FUSION_CONVADD_OP ON)
set(FOUND_MATCH ON)
......@@ -220,6 +221,8 @@ if(NOT FOUND_MATCH)
set(SPLIT_OP ON)
set(FLATTEN_OP ON)
set(SHAPE_OP ON)
set(ELEMENTWISEMUL_OP ON)
set(SUM_OP ON)
endif()
# option(BATCHNORM_OP "" ON)
......@@ -388,3 +391,11 @@ endif()
if (SHAPE_OP)
add_definitions(-DSHAPE_OP)
endif()
if (ELEMENTWISEMUL_OP)
add_definitions(-DELEMENTWISEMUL_OP)
endif()
if (SUM_OP)
add_definitions(-DSUM_OP)
endif()
......@@ -45,13 +45,13 @@ def combine_bgrs_nchw(bgrs, means_b_g_r, scale, channel_type=ChannelType.BGR):
print '------------------'
print bgrs_float_array[0]
print bgrs_float_array[416 * 416 * 2 + 416 * 2 + 2]
print bgrs_float_array[224 * 224 * 2 + 224 * 2 + 2]
# for i in range(0, 9):
# print'bs %d' % i
# print bs[i] / 255.
print bs[416 * 2 + 2] / 255.
print bs[224 * 2 + 2] / 255.
print '--------------combine_bgrs_nchw-----------------end'
return bgrs_float_array
......@@ -64,6 +64,6 @@ def combine_bgrs_nchw(bgrs, means_b_g_r, scale, channel_type=ChannelType.BGR):
# cv2.waitKey(0)
bgrs = tools.resize_take_rgbs('datas/newyolo.jpg', (416, 416, 3))
bgrs = tools.resize_take_rgbs('datas/jpgs/0000_0.9834-148196_82452-0ad4b83ec6bc0f9c5f28101539267054.jpg_p0_0.126571263346.jpg', (224, 224, 3))
array = combine_bgrs_nchw(bgrs, (0, 0, 0), 1. / 255, ChannelType.RGB)
tools.save_to_file('datas/desktop_1_3_416_416_nchw_float', array)
tools.save_to_file('datas/desktop_1_3_224_224_nchw_float', array)
......@@ -15,11 +15,11 @@ from array import array
# image.resize(shape_h_w)
data = np.fromfile('datas/img.res')
data = np.fromfile('/Users/xiebaiyuan/PaddleProject/paddle-mobile/tools/python/imagetools/datas/jpgs2/0000_0.9834-148196_82452-0ad4b83ec6bc0f9c5f28101539267054.jpg_p0_0.126571263346.jpg.input.npfile','f')
print data.size
print data[0]
print data
data.reshape(1, 3, 416, 416)
data.reshape(1, 3, 224, 224)
out_array = array('f')
print'--------------------'
print data.size
......@@ -27,12 +27,12 @@ print data[0]
print '如果是nhwc --------'
# rgb rgb rgb rgb rgb
print data[416 * 3 * 2 + 3 * 2 + 2]
print data[224 * 3 * 2 + 3 * 2 + 2]
# print data[2]
print '如果是nchw --------'
# rgb rgb rgb rgb rgb
print data[416 * 416 * 2 + 416 * 2 + 2]
print data[224 * 224 * 2 + 224 * 2 + 2]
# print data[2]
# 明明是nchw
......@@ -42,6 +42,8 @@ for i in range(0, data.size):
print len(out_array)
print out_array[416 * 416 * 2 + 416 * 2 + 2]
print out_array[224 * 224 * 2 + 224 * 2 + 2]
# print out_array
tools.save_to_file('datas/in_put_1_3_416_416_2', out_array)
tools.save_to_file('datas/in_put_1_3_224_224_nchw', out_array)
......@@ -77,6 +77,14 @@ fusion_conv_add_attrs_dict = {
'strides': 'stride',
'groups': 'group'
}
# fluid attr key --- mdl params key
pool2d_attrs_dict = {
'global_pooling': 'global_pooling',
'pooling_type': 'type'
}
# fluid attr key --- mdl params key
fluid_attrs_type_dict = {
'paddings': 0,
......
# coding=utf-8
import json
import os
......@@ -12,13 +13,25 @@ def load_mdl(mdl_json_path):
return json.load(f)
def create_if_not_exit(target_dir):
if os.path.exists(target_dir):
shutil.rmtree(target_dir)
os.makedirs(target_dir, 0777)
class Converter:
'convert mdlmodel to fluidmodel'
def __init__(self, base_dir, mdl_json_path):
print 'base_dir: ' + base_dir
self.mdl_json_path = base_dir + mdl_json_path
self.base_dir = base_dir
print mdl_json_path
self.source_weights_dir = self.base_dir + 'datas/sourcemodels/source_weights/'
self.target_weight_dir = self.base_dir + 'datas/target/target_weights/'
create_if_not_exit(self.target_weight_dir)
self.mdl_json = load_mdl(self.mdl_json_path)
self.program_desc = framework_pb2.ProgramDesc()
self.weight_list_ = []
......@@ -41,16 +54,18 @@ class Converter:
print 'convert end.....'
desc_serialize_to_string = self.program_desc.SerializeToString()
outputmodel_ = self.base_dir + 'datas/target/outputmodel/'
if os.path.exists(outputmodel_):
shutil.rmtree(outputmodel_)
os.makedirs(outputmodel_, 0777)
# todo copy weight files
# if os.path.exists(outputmodel_):
# shutil.rmtree(outputmodel_)
# shutil.copytree('yolo/datas/multiobjects/float32s_nchw_with_head/', 'mobilenet/datas/target/outputmodel/')
outputmodel_dir = self.base_dir + 'datas/target/mobilenet_classfication/'
if os.path.exists(outputmodel_dir):
shutil.rmtree(outputmodel_dir)
os.makedirs(outputmodel_dir, 0777)
f = open(outputmodel_ + "__model__", "wb")
if os.path.exists(outputmodel_dir):
shutil.rmtree(outputmodel_dir)
# create_if_not_exit(outputmodel_dir)
shutil.copytree(self.target_weight_dir, outputmodel_dir)
f = open(outputmodel_dir + "__model__", "wb")
f.write(desc_serialize_to_string)
f.close()
......@@ -63,26 +78,30 @@ class Converter:
layers_ = self.mdl_json['layer']
for layer in layers_:
desc_ops_add = block_desc.ops.add()
# print layer
# for i in layer:
# print i
if 'name' in layer:
l_name = layer['name']
if 'type' in layer:
self.package_ops_type(desc_ops_add, layer)
if layer['type'] == 'SoftmaxLayer':
pass
else:
desc_ops_add = block_desc.ops.add()
# print layer
# for i in layer:
# print i
if 'name' in layer:
l_name = layer['name']
if 'type' in layer:
self.package_ops_type(desc_ops_add, layer)
if 'weight' in layer:
self.package_ops_weight2inputs(desc_ops_add, layer)
if 'weight' in layer:
self.package_ops_weight2inputs(desc_ops_add, layer)
if 'output' in layer:
self.package_ops_outputs(desc_ops_add, layer)
if 'output' in layer:
self.package_ops_outputs(desc_ops_add, layer)
if 'input' in layer:
self.package_ops_inputs(desc_ops_add, layer)
if 'input' in layer:
self.package_ops_inputs(desc_ops_add, layer)
self.package_ops_attrs(desc_ops_add, layer)
self.package_ops_attrs(desc_ops_add, layer)
self.add_op_fetch(block_desc)
......@@ -105,7 +124,8 @@ class Converter:
desc_ops_add = block_desc.ops.add()
inputs_add = desc_ops_add.inputs.add()
inputs_add.parameter = 'X'
inputs_add.arguments.append('conv_pred_87')
# todo pick last layer --> op output
inputs_add.arguments.append('fc7')
desc_ops_add.type = 'fetch'
outputs_add = desc_ops_add.outputs.add()
outputs_add.parameter = 'Out'
......@@ -129,6 +149,128 @@ class Converter:
# boolean
attrs_add.type = 6
attrs_add.b = 0
elif desc_ops_add.type == types.op_fluid_pooling:
Converter.pack_pooling_attr(desc_ops_add, layer)
pass
elif desc_ops_add.type == types.op_fluid_softmax:
pass
@staticmethod
def pack_pooling_attr(desc_ops_add, layer):
print layer
l_params = layer['param']
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'use_mkldnn'
# boolean
attrs_add.type = 6
attrs_add.b = 0
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'use_cudnn'
# boolean
attrs_add.type = 6
attrs_add.b = 1
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'paddings'
# ints
attrs_add.type = 3
attrs_add.ints.append(0)
attrs_add.ints.append(0)
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'strides'
# ints
attrs_add.type = 3
attrs_add.ints.append(1)
attrs_add.ints.append(1)
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'global_pooling'
# boolean
attrs_add.type = 6
attrs_add.b = (l_params[types.pool2d_attrs_dict.get('global_pooling')])
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'pooling_type'
# 2-->STRING
attrs_add.type = 2
# 注意这里 avg but mdl is ave
attrs_add.s = l_params[types.pool2d_attrs_dict.get('pooling_type')]
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'ceil_mode'
# boolean
attrs_add.type = 6
attrs_add.b = 1
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'ksize'
# ints
attrs_add.type = 3
attrs_add.ints.append(7)
attrs_add.ints.append(7)
# type: "pool2d"
# attrs
# {
# name: "use_mkldnn"
# type: BOOLEAN
# b: false
# }
# attrs
# {
# name: "ceil_mode"
# type: BOOLEAN
# b: true
# }
# attrs
# {
# name: "use_cudnn"
# type: BOOLEAN
# b: true
# }
# attrs
# {
# name: "paddings"
# type: INTS
# ints: 0
# ints: 0
# }
# attrs
# {
# name: "strides"
# type: INTS
# ints: 1
# ints: 1
# }
# attrs
# {
# name: "global_pooling"
# type: BOOLEAN
# b: false
# }
# attrs
# {
# name: "data_format"
# type: STRING
# s: "AnyLayout"
# }
# attrs
# {
# name: "ksize"
# type: INTS
# ints: 7
# ints: 7
# }
# attrs
# {
# name: "pooling_type"
# type: STRING
# s: "avg"
# }
# is_target: false
@staticmethod
def pack_fusion_conv_add_attr(desc_ops_add, layer):
......@@ -181,6 +323,13 @@ class Converter:
attrs_add.ints.append(l_params[types.fusion_conv_add_attrs_dict.get('paddings')])
attrs_add.ints.append(l_params[types.fusion_conv_add_attrs_dict.get('paddings')])
# attrs_add = desc_ops_add.attrs.add()
# attrs_add.name = 'paddings'
# # ints
# attrs_add.type = 3
# attrs_add.ints.append(0)
# attrs_add.ints.append(0)
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'strides'
# ints
......@@ -188,6 +337,13 @@ class Converter:
attrs_add.ints.append(l_params[types.fusion_conv_add_attrs_dict.get('strides')])
attrs_add.ints.append(l_params[types.fusion_conv_add_attrs_dict.get('strides')])
# attrs_add = desc_ops_add.attrs.add()
# attrs_add.name = 'strides'
# # ints
# attrs_add.type = 3
# attrs_add.ints.append(6)
# attrs_add.ints.append(6)
attrs_add = desc_ops_add.attrs.add()
attrs_add.name = 'groups'
# int
......@@ -232,8 +388,8 @@ class Converter:
# print o
outputs_add = desc_ops_add.outputs.add()
dict = types.op_io_dict.get(desc_ops_add.type)
print 'desc_ops_add.type: ' + desc_ops_add.type
print dict
# print 'desc_ops_add.type: ' + desc_ops_add.type
# print dict
outputs_add.parameter = dict.get(types.mdl_outputs_key)
outputs_add.arguments.append(o)
......@@ -305,7 +461,7 @@ class Converter:
# issues in mdl model filter swich n and c
if j in self.deepwise_weight_list_ and len(dims_of_matrix) == 4:
print j
print "deep wise issue fit: " + j
tensor.dims.append(dims_of_matrix[1])
tensor.dims.append(dims_of_matrix[0])
tensor.dims.append(dims_of_matrix[2])
......@@ -320,6 +476,12 @@ class Converter:
vars_add.persistable = 1
dims_size = len(dims_of_matrix)
# print dims_size
# print 'weight name : ' + j
Swichter().copy_add_head(
self.source_weights_dir + j + '.bin',
self.target_weight_dir + j
)
# if dims_size == 4:
# # convert weight from nhwc to nchw
# Swichter().nhwc2nchw_one_slice_add_head(
......@@ -341,7 +503,7 @@ class Converter:
vars_add.persistable = 0
mdl_path = "datas/sourcemodels/cls231_0802/mobileNetModel.json"
mdl_path = "datas/sourcemodels/source_profile/mobileNetModel.json"
base_dir = "/Users/xiebaiyuan/PaddleProject/paddle-mobile/tools/python/modeltools/mobilenet/"
converter = Converter(base_dir, mdl_path)
converter.convert()
import os
import shutil
from array import array
......@@ -58,7 +60,7 @@ class Swichter:
to_file = open(to_file_name, "wb")
tmp = tmp_file.read()
head = self.read_head('yolo/datas/yolo/conv1_biases')
head = self.read_head('yolo/datas/yolo/head')
to_file.write(head)
to_file.write(tmp)
tmp_file.close()
......@@ -72,12 +74,14 @@ class Swichter:
# print read
return read
def copy_add_head(self, from_file_name, to_file_name, tmp_file_name):
def copy_add_head(self, from_file_name, to_file_name):
from_file = open(from_file_name, "rb")
to_file = open(to_file_name, "wb")
# tmp_file = open(tmp_file_name, "wb")
head = self.read_head('yolo/datas/yolo/conv1_biases')
head = self.read_head(
'/Users/xiebaiyuan/PaddleProject/paddle-mobile/tools/python/modeltools/mobilenet/datas/sourcemodels/head/head')
to_file.write(head)
to_file.write(from_file.read())
from_file.close()
......@@ -96,7 +100,7 @@ class Swichter:
to_file = open(to_file_name, "wb")
# tmp_file = open(tmp_file_name, "wb")
head = self.read_head('yolo/datas/yolo/conv1_biases')
head = self.read_head('yolo/datas/yolo/head')
to_file.write(head)
to_file.write(read)
from_file.close()
......@@ -110,6 +114,6 @@ class Swichter:
# 32,
# 3, 3, 3)
# Swichter().read_head('/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/modeltools/yolo/conv1_biases')
# Swichter().read_head('/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/modeltools/yolo/head')
# Swichter().copy_add_head('datas/model.0.0.weight', 'datas/conv1_0', '')
......@@ -58,7 +58,7 @@ class Swichter:
to_file = open(to_file_name, "wb")
tmp = tmp_file.read()
head = self.read_head('yolo/datas/yolo/conv1_biases')
head = self.read_head('yolo/datas/yolo/head')
to_file.write(head)
to_file.write(tmp)
tmp_file.close()
......@@ -77,7 +77,7 @@ class Swichter:
to_file = open(to_file_name, "wb")
# tmp_file = open(tmp_file_name, "wb")
head = self.read_head('yolo/datas/yolo/conv1_biases')
head = self.read_head('yolo/datas/yolo/head')
to_file.write(head)
to_file.write(from_file.read())
from_file.close()
......@@ -96,7 +96,7 @@ class Swichter:
to_file = open(to_file_name, "wb")
# tmp_file = open(tmp_file_name, "wb")
head = self.read_head('yolo/datas/yolo/conv1_biases')
head = self.read_head('yolo/datas/yolo/head')
to_file.write(head)
to_file.write(read)
from_file.close()
......@@ -110,6 +110,6 @@ class Swichter:
# 32,
# 3, 3, 3)
# Swichter().read_head('/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/modeltools/yolo/conv1_biases')
# Swichter().read_head('/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/modeltools/yolo/head')
# Swichter().copy_add_head('datas/model.0.0.weight', 'datas/conv1_0', '')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册