提交 45de1e84 编写于 作者: 朔-望's avatar 朔-望 提交者: GitHub

Merge pull request #198 from allonli/develop

Add clang-tidy and clang-format hook
...@@ -3,5 +3,4 @@ Language: Cpp ...@@ -3,5 +3,4 @@ Language: Cpp
BasedOnStyle: LLVM BasedOnStyle: LLVM
Standard: Cpp11 Standard: Cpp11
IndentWidth: 4 IndentWidth: 4
NamespaceIndentation: All
... ...
Checks: >
*
-android-*
-bugprone-bool-pointer-implicit-conversion
-cert-env33-c
-cert-dcl50-cpp
-cert-dcl59-cpp
-cppcoreguidelines-*
-fuchsia-*
-google-*
google-default-arguments
google-explicit-constructor
google-runtime-member-string-references
google-runtime-operator
-hicpp-braces-around-statements
-hicpp-named-parameter
-hicpp-no-array-decay
-hicpp-no-assembler
-hicpp-no-malloc
-hicpp-function-size
-hicpp-special-member-functions
-hicpp-vararg
-llvm-*
-objc-*
-readability-else-after-return
-readability-implicit-bool-conversion
-readability-named-parameter
-readability-simplify-boolean-expr
-readability-braces-around-statements
-readability-identifier-naming
-readability-function-size
-readability-redundant-member-init
-misc-bool-pointer-implicit-conversion
-misc-definitions-in-headers
-misc-unused-alias-decls
-misc-unused-parameters
-misc-unused-using-decls
-modernize-use-using
-modernize-use-default-member-init
-clang-diagnostic-*
-clang-analyzer-*
WarningsAsErrors: ''
HeaderFilterRegex: ''
AnalyzeTemporaryDtors: false
FormatStyle: none
User: allonli
CheckOptions:
- key: google-readability-braces-around-statements.ShortStatementLines
value: '1'
- key: google-readability-function-size.StatementThreshold
value: '800'
- key: google-readability-namespace-comments.ShortNamespaceLines
value: '10'
- key: google-readability-namespace-comments.SpacesBeforeComments
value: '2'
- key: modernize-loop-convert.MaxCopySize
value: '16'
- key: modernize-loop-convert.MinConfidence
value: reasonable
- key: modernize-loop-convert.NamingStyle
value: CamelCase
- key: modernize-pass-by-value.IncludeStyle
value: llvm
- key: modernize-replace-auto-ptr.IncludeStyle
value: llvm
- key: modernize-use-nullptr.NullMacros
value: 'NULL'
...@@ -22,13 +22,22 @@ repos: ...@@ -22,13 +22,22 @@ repos:
- repo: local - repo: local
hooks: hooks:
- id: clang-format-with-version-check - id: clang-format
name: clang-format name: clang-format
description: Format files with ClangFormat. description: Format files with ClangFormat.
entry: bash ./tools/pre-commit.hooks/.clang_format.hook -i entry: bash ./tools/pre-commit.hooks/.clang-format.hook -i
language: system language: system
files: (src).*\.(c|cc|cxx|cpp|h|hpp|hxx)$ files: \.(c|cc|cxx|cpp|h|hpp|hxx)$
- repo: local
hooks:
- id: clang-tidy
name: clang-tidy
description: Check C++ code style using clang-tidy.
entry: bash ./tools/pre-commit.hooks/.clang-tidy.hook -i
language: system
files: (src).*\.(c|cc|cxx|cpp|h|hpp|hxx)$
#
#- repo: local #- repo: local
# hooks: # hooks:
# - id: copyright_checker # - id: copyright_checker
......
...@@ -6,26 +6,54 @@ dist: trusty ...@@ -6,26 +6,54 @@ dist: trusty
os: os:
- linux - linux
env: env:
- JOB=check_style global:
- CMAKE_URL=https://cmake.org/files/v3.11/cmake-3.11.1-Linux-x86_64.tar.gz
addons: addons:
apt: apt:
sources:
- llvm-toolchain-trusty-6.0
- ubuntu-toolchain-r-test
packages: packages:
- git - git
- python - python
- python-pip - python-pip
- python2.7-dev - python2.7-dev
- clang-format-3.8 - libc6-i386
- clang-6.0
- libclang-6.0
- llvm-6.0
- llvm-6.0-dev
- curl
compiler:
- clang
before_install: before_install:
- sudo pip install -U virtualenv pre-commit pip - sudo pip install -U virtualenv pre-commit pip
# Download and install recent cmake
- |
if [[ ${TRAVIS_OS_NAME} == "linux" ]]; then
CMAKE_URL=${CMAKE_URL}
mkdir -p ${DEPS_DIR}/cmake
travis_retry wget --no-check-certificate --quiet -O - ${CMAKE_URL} | tar --strip-components=1 -xz -C ${DEPS_DIR}/cmake
export PATH=${DEPS_DIR}/cmake/bin:${PATH}
fi
script: #install:
- if [[ "$JOB" == "check_style" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi # - if [ "$CXX" = "g++" ]; then export CXX="g++-5" CC="gcc-5"; fi
# - if [ "$CXX" = "clang++" ]; then export CXX="clang++-6.0" CC="clang-6.0"; fi
before_script:
- | - |
echo "cmake generate compile_commands.json for clang-tidy"
ls -l -a
clang-tidy -version
clang-format -version
script:
- |
function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; } function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; }
- | - |
timeout 600 .travis/${JOB}.sh # 10min timeout timeout 600 .travis/pre-commit-job.sh # 10min timeout
RESULT=$?; if [ $RESULT -eq 0 ] || [ $RESULT -eq 142 ]; then true; else exit 1; fi; RESULT=$?; if [ $RESULT -eq 0 ] || [ $RESULT -eq 142 ]; then true; else exit 1; fi;
notifications: notifications:
......
#!/bin/bash #!/bin/bash
function abort(){ function abort(){
echo "Your change doesn't follow PaddlePaddle's code style" 1>&2 echo "Your change doesn't follow Paddle-Moible's code style" 1>&2
echo "Please use pre-commit to auto-format your code." 1>&2 echo "Please use pre-commit to auto-format your code." 1>&2
exit 1 exit 1
} }
...@@ -11,7 +11,6 @@ cd `dirname $0` ...@@ -11,7 +11,6 @@ cd `dirname $0`
cd .. cd ..
export PATH=/usr/bin:$PATH export PATH=/usr/bin:$PATH
pre-commit install pre-commit install
clang-format --version
if ! pre-commit run -a ; then if ! pre-commit run -a ; then
ls -lh ls -lh
......
...@@ -5,7 +5,8 @@ add_definitions(-DPADDLE_MOBILE_DEBUG="true") ...@@ -5,7 +5,8 @@ add_definitions(-DPADDLE_MOBILE_DEBUG="true")
set(CMAKE_BUILD_TYPE RelWithDebInfo) set(CMAKE_BUILD_TYPE RelWithDebInfo)
set(CMAKE_VERBOSE_MAKEFILE on) set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY build) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY build)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY build) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY build)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY build) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY build)
......
#!/bin/bash #!/bin/bash
build_for_linux() { build_for_linux() {
echo "linux" if [ ! `which brew` ]; then
echo "building failed! homebrew not found, please install homebrew."
return
fi
if [ ! `which cmake` ]; then
echo "installing cmake."
brew install cmake
if [ ! $? ]; then
echo "cmake install failed."
return
fi
fi
PLATFORM="x86"
MODE="Release"
CXX_FLAGS="-std=c++11 -O3 -s"
BUILD_DIR=build/release/"${PLATFORM}"
mkdir -p ${BUILD_DIR}/build
mkdir -p ${BUILD_DIR}/test
cp -r test/models ${BUILD_DIR}/test/models
cmake . \
-B"${BUILD_DIR}" \
-DCMAKE_BUILD_TYPE="${MODE}" \
-DCMAKE_CXX_FLAGS="${CXX_FLAGS}" \
-DIS_MAC=true
cd ${BUILD_DIR}
make -j 8
} }
build_for_mac() { build_for_mac() {
......
此差异已折叠。
...@@ -27,146 +27,145 @@ SOFTWARE. ...@@ -27,146 +27,145 @@ SOFTWARE.
namespace paddle_mobile { namespace paddle_mobile {
enum LogLevel { enum LogLevel {
kNO_LOG, kNO_LOG,
kLOG_ERROR, kLOG_ERROR,
kLOG_WARNING, kLOG_WARNING,
kLOG_INFO, kLOG_INFO,
kLOG_DEBUG, kLOG_DEBUG,
kLOG_DEBUG1, kLOG_DEBUG1,
kLOG_DEBUG2, kLOG_DEBUG2,
kLOG_DEBUG3, kLOG_DEBUG3,
kLOG_DEBUG4 kLOG_DEBUG4
}; };
// log level // log level
static LogLevel log_level = kLOG_DEBUG4; static LogLevel log_level = kLOG_DEBUG4;
static std::vector<std::string> logs{"NO", "ERROR ", "WARNING", static std::vector<std::string> logs{"NO", "ERROR ", "WARNING",
"INFO ", "DEBUG ", "DEBUG1 ", "INFO ", "DEBUG ", "DEBUG1 ",
"DEBUG2 ", "DEBUG3 ", "DEBUG4 "}; "DEBUG2 ", "DEBUG3 ", "DEBUG4 "};
struct ToLog; struct ToLog;
struct Print; struct Print;
struct Print { struct Print {
friend struct ToLog; friend struct ToLog;
template <typename T> Print &operator<<(T const &value) { template <typename T> Print &operator<<(T const &value) {
buffer_ << value; buffer_ << value;
return *this; return *this;
}
private:
void print(LogLevel level) {
buffer_ << std::endl;
if (level == kLOG_ERROR) {
std::cerr << buffer_.str();
} else {
std::cout << buffer_.str();
} }
}
private: std::ostringstream buffer_;
void print(LogLevel level) { };
buffer_ << std::endl;
if (level == kLOG_ERROR) { struct ToLog {
std::cerr << buffer_.str(); ToLog(LogLevel level = kLOG_DEBUG, const std::string &info = "")
} else { : level_(level) {
std::cout << buffer_.str(); unsigned blanks =
} (unsigned)(level > kLOG_DEBUG ? (level - kLOG_DEBUG) * 4 : 1);
} printer_ << logs[level] << " " << info << ":"
std::ostringstream buffer_; << std::string(blanks, ' ');
}; }
struct ToLog { template <typename T> ToLog &operator<<(T const &value) {
ToLog(LogLevel level = kLOG_DEBUG, const std::string &info = "") printer_ << value;
: level_(level) { return *this;
unsigned blanks = }
(unsigned)(level > kLOG_DEBUG ? (level - kLOG_DEBUG) * 4 : 1);
printer_ << logs[level] << " " << info << ":" ~ToLog() { printer_.print(level_); }
<< std::string(blanks, ' ');
} private:
LogLevel level_;
template <typename T> ToLog &operator<<(T const &value) { Print printer_;
printer_ << value; };
return *this;
}
~ToLog() { printer_.print(level_); }
private:
LogLevel level_;
Print printer_;
};
#define LOG(level) \ #define LOG(level) \
if (level > paddle_mobile::log_level) { \ if (level > paddle_mobile::log_level) { \
} else \ } else \
paddle_mobile::ToLog( \ paddle_mobile::ToLog( \
level, \ level, (std::stringstream() \
(std::stringstream() \ << "[file: " \
<< "[file: " \ << (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) \
<< (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) : __FILE__) \ : __FILE__) \
<< "] [line: " << __LINE__ << "] ") \ << "] [line: " << __LINE__ << "] ") \
.str()) .str())
#define DLOG \ #define DLOG \
if (paddle_mobile::kLOG_DEBUG > paddle_mobile::log_level) { \ if (paddle_mobile::kLOG_DEBUG > paddle_mobile::log_level) { \
} else \ } else \
paddle_mobile::ToLog( \ paddle_mobile::ToLog( \
paddle_mobile::kLOG_DEBUG, \ paddle_mobile::kLOG_DEBUG, \
(std::stringstream() \ (std::stringstream() \
<< "[file: " \ << "[file: " \
<< (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) : __FILE__) \ << (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) \
<< "] [line: " << __LINE__ << "] ") \ : __FILE__) \
.str()) << "] [line: " << __LINE__ << "] ") \
} .str())
} // namespace paddle_mobile
#define LOGF(level, format, ...) \ #define LOGF(level, format, ...) \
if (level > paddle_mobile::log_level) { \ if (level > paddle_mobile::log_level) { \
} else \ } else \
printf(format, ##__VA_ARGS__) printf(format, ##__VA_ARGS__)
#define DLOGF(format, ...) \ #define DLOGF(format, ...) \
if (paddle_mobile::kLOG_DEBUG > paddle_mobile::log_level) { \ if (paddle_mobile::kLOG_DEBUG > paddle_mobile::log_level) { \
} else \ } else \
printf(format, ##__VA_ARGS__) printf(format, ##__VA_ARGS__)
#else #else
namespace paddle_mobile { namespace paddle_mobile {
enum LogLevel { enum LogLevel {
kNO_LOG, kNO_LOG,
kLOG_ERROR, kLOG_ERROR,
kLOG_WARNING, kLOG_WARNING,
kLOG_INFO, kLOG_INFO,
kLOG_DEBUG, kLOG_DEBUG,
kLOG_DEBUG1, kLOG_DEBUG1,
kLOG_DEBUG2, kLOG_DEBUG2,
kLOG_DEBUG3, kLOG_DEBUG3,
kLOG_DEBUG4 kLOG_DEBUG4
}; };
struct ToLog; struct ToLog;
struct Print { struct Print {
friend struct ToLog; friend struct ToLog;
template <typename T> Print &operator<<(T const &value) {} template <typename T> Print &operator<<(T const &value) {}
private: private:
}; };
struct ToLog { struct ToLog {
ToLog(LogLevel level) {} ToLog(LogLevel level) {}
template <typename T> ToLog &operator<<(T const &value) { template <typename T> ToLog &operator<<(T const &value) { return *this; }
return *this; };
}
};
#define LOG(level) \ #define LOG(level) \
if (true) { \ if (true) { \
} else \ } else \
paddle_mobile::ToLog(level) paddle_mobile::ToLog(level)
#define DLOG \ #define DLOG \
if (true) { \ if (true) { \
} else \ } else \
paddle_mobile::ToLog(paddle_mobile::kLOG_DEBUG) paddle_mobile::ToLog(paddle_mobile::kLOG_DEBUG)
#define LOGF(level, format, ...) #define LOGF(level, format, ...)
#define DLOGF(format, ...) #define DLOGF(format, ...)
} } // namespace paddle_mobile
#endif #endif
...@@ -23,31 +23,30 @@ SOFTWARE. ...@@ -23,31 +23,30 @@ SOFTWARE.
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
template <typename Dtype> class OperatorBase; template <typename Dtype> class OperatorBase;
class OpDesc; class OpDesc;
class BlockDesc; class BlockDesc;
class InferShapeContext; class InferShapeContext;
} } // namespace framework
using VariableNameMap = std::map<std::string, std::vector<std::string>>; using VariableNameMap = std::map<std::string, std::vector<std::string>>;
template <typename Dtype> template <typename Dtype>
using OpCreator = std::function<framework::OperatorBase<Dtype> *( using OpCreator = std::function<framework::OperatorBase<Dtype> *(
const std::string & /*type*/, const VariableNameMap & /*inputs*/, const std::string & /*type*/, const VariableNameMap & /*inputs*/,
const VariableNameMap & /*outputs*/, const VariableNameMap & /*outputs*/,
const framework::AttributeMap & /*attrs*/)>; const framework::AttributeMap & /*attrs*/)>;
using GradOpMakerFN = using GradOpMakerFN =
std::function<std::vector<std::unique_ptr<framework::OpDesc>>( std::function<std::vector<std::unique_ptr<framework::OpDesc>>(
const framework::OpDesc &, const framework::OpDesc &,
const std::unordered_set<std::string> & /*no_grad_set*/, const std::unordered_set<std::string> & /*no_grad_set*/,
std::unordered_map<std::string, std::string> * /*grad_to_var*/, std::unordered_map<std::string, std::string> * /*grad_to_var*/,
const std::vector<framework::BlockDesc *> &grad_block)>; const std::vector<framework::BlockDesc *> &grad_block)>;
using InferVarTypeFN = using InferVarTypeFN = std::function<void(const framework::OpDesc & /*op_desc*/,
std::function<void(const framework::OpDesc & /*op_desc*/, framework::BlockDesc * /*block*/)>;
framework::BlockDesc * /*block*/)>;
using InferShapeFN = std::function<void(framework::InferShapeContext *)>;
using InferShapeFN = std::function<void(framework::InferShapeContext *)>; }; // namespace paddle_mobile
};
...@@ -19,45 +19,45 @@ SOFTWARE. ...@@ -19,45 +19,45 @@ SOFTWARE.
#pragma once; #pragma once;
namespace paddle_mobile { namespace paddle_mobile {
enum class Precision : int { FP32 = 0 }; enum class Precision : int { FP32 = 0 };
//! device type //! device type
enum DeviceTypeEnum { kINVALID = -1, kCPU = 0, kFPGA = 1, kGPU_MALI = 2 }; enum DeviceTypeEnum { kINVALID = -1, kCPU = 0, kFPGA = 1, kGPU_MALI = 2 };
template <DeviceTypeEnum T> struct DeviceType {}; template <DeviceTypeEnum T> struct DeviceType {};
typedef DeviceType<kCPU> CPU; typedef DeviceType<kCPU> CPU;
typedef DeviceType<kFPGA> FPGA; typedef DeviceType<kFPGA> FPGA;
typedef DeviceType<kGPU_MALI> GPU_MALI; typedef DeviceType<kGPU_MALI> GPU_MALI;
//! data type //! data type
enum DataType { enum DataType {
PM_INVALID = -1, PM_INVALID = -1,
PM_HALF = 0, PM_HALF = 0,
PM_FLOAT = 1, PM_FLOAT = 1,
PM_DOUBLE = 2, PM_DOUBLE = 2,
PM_INT8 = 3, PM_INT8 = 3,
PM_INT16 = 4, PM_INT16 = 4,
PM_INT32 = 5, PM_INT32 = 5,
PM_INT64 = 6, PM_INT64 = 6,
PM_UINT8 = 7, PM_UINT8 = 7,
PM_UINT16 = 8, PM_UINT16 = 8,
PM_UINT32 = 9, PM_UINT32 = 9,
PM_STRING = 10, PM_STRING = 10,
PM_BOOL = 11, PM_BOOL = 11,
PM_SHAPE = 12, PM_SHAPE = 12,
PM_TENSOR = 13 PM_TENSOR = 13
}; };
//! //!
enum PMStatus { enum PMStatus {
PMSuccess = 0xFF, /*!< No errors */ PMSuccess = 0xFF, /*!< No errors */
PMNotInitialized = 0x01, /*!< Data not initialized. */ PMNotInitialized = 0x01, /*!< Data not initialized. */
PMInvalidValue = 0x02, /*!< Incorrect variable value. */ PMInvalidValue = 0x02, /*!< Incorrect variable value. */
PMMemAllocFailed = 0x03, /*!< Memory allocation error. */ PMMemAllocFailed = 0x03, /*!< Memory allocation error. */
PMUnKownError = 0x04, /*!< Unknown error. */ PMUnKownError = 0x04, /*!< Unknown error. */
PMOutOfAuthority = 0x05, /*!< Try to modified data not your own*/ PMOutOfAuthority = 0x05, /*!< Try to modified data not your own*/
PMOutOfMem = 0x06, /*!< OOM error*/ PMOutOfMem = 0x06, /*!< OOM error*/
PMUnImplError = 0x07, /*!< Unimplement error. */ PMUnImplError = 0x07, /*!< Unimplement error. */
PMWrongDevice = 0x08 /*!< un-correct device. */ PMWrongDevice = 0x08 /*!< un-correct device. */
}; };
} } // namespace paddle_mobile
...@@ -15,5 +15,3 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, ...@@ -15,5 +15,3 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. SOFTWARE.
==============================================================================*/ ==============================================================================*/
#include "variant.h"
...@@ -21,79 +21,79 @@ SOFTWARE. ...@@ -21,79 +21,79 @@ SOFTWARE.
#pragma once #pragma once
namespace paddle_mobile { namespace paddle_mobile {
template <int ID, typename Type> struct IDToType { typedef Type type_t; }; template <int ID, typename Type> struct IDToType { typedef Type type_t; };
template <typename F, typename... Ts> struct VariantHelper { template <typename F, typename... Ts> struct VariantHelper {
static const size_t size = sizeof(F) > VariantHelper<Ts...>::size static const size_t size = sizeof(F) > VariantHelper<Ts...>::size
? sizeof(F) ? sizeof(F)
: VariantHelper<Ts...>::size; : VariantHelper<Ts...>::size;
inline static void Destroy(size_t id, void *data) { inline static void Destroy(size_t id, void *data) {
if (id == typeid(F).hash_code()) { if (id == typeid(F).hash_code()) {
reinterpret_cast<F *>(data)->~F(); reinterpret_cast<F *>(data)->~F();
} else { } else {
VariantHelper<Ts...>::Destroy(id, data); VariantHelper<Ts...>::Destroy(id, data);
}
} }
}; }
};
template <typename F> struct VariantHelper<F> { template <typename F> struct VariantHelper<F> {
static const size_t size = sizeof(F); static const size_t size = sizeof(F);
inline static void Destroy(size_t id, void *data) { inline static void Destroy(size_t id, void *data) {
if (id == typeid(F).hash_code()) { if (id == typeid(F).hash_code()) {
// reinterpret_cast<F*>(data)->~F(); // reinterpret_cast<F*>(data)->~F();
} else { } else {
// std::cout << "未匹配到 " << std::endl; // std::cout << "未匹配到 " << std::endl;
}
} }
}; }
};
template <size_t size> class RawData { template <size_t size> class RawData {
public: public:
char data[size]; char data[size];
RawData() {} RawData() {}
RawData(const RawData &raw_data) { strcpy(data, raw_data.data); } RawData(const RawData &raw_data) { strcpy(data, raw_data.data); }
// void operator=(const RawData &raw_data){ // void operator=(const RawData &raw_data){
// strcpy(data, raw_data.data); // strcpy(data, raw_data.data);
// } // }
}; };
template <typename... Ts> struct Variant { template <typename... Ts> struct Variant {
Variant(const Variant &variant) { Variant(const Variant &variant) {
// std::cout << " 赋值构造函数 " << std::endl; // std::cout << " 赋值构造函数 " << std::endl;
type_id = variant.type_id; type_id = variant.type_id;
data = variant.data; data = variant.data;
} }
Variant() : type_id(invalid_type()) {} Variant() : type_id(invalid_type()) {}
~Variant() { ~Variant() {
// helper::Destroy(type_id, &data); // helper::Destroy(type_id, &data);
} }
template <typename T, typename... Args> void Set(Args &&... args) { template <typename T, typename... Args> void Set(Args &&... args) {
helper::Destroy(type_id, &data); helper::Destroy(type_id, &data);
new (&data) T(std::forward<Args>(args)...); new (&data) T(std::forward<Args>(args)...);
type_id = typeid(T).hash_code(); type_id = typeid(T).hash_code();
} }
template <typename T> T &Get() const { template <typename T> T &Get() const {
if (type_id == typeid(T).hash_code()) { if (type_id == typeid(T).hash_code()) {
return *const_cast<T *>(reinterpret_cast<const T *>(&data)); return *const_cast<T *>(reinterpret_cast<const T *>(&data));
} else { } else {
// std::cout << " bad cast in variant " << std::endl; // std::cout << " bad cast in variant " << std::endl;
throw std::bad_cast(); throw std::bad_cast();
}
} }
}
size_t TypeId() const { return type_id; } size_t TypeId() const { return type_id; }
private: private:
static inline size_t invalid_type() { return typeid(void).hash_code(); } static inline size_t invalid_type() { return typeid(void).hash_code(); }
typedef VariantHelper<Ts...> helper; typedef VariantHelper<Ts...> helper;
size_t type_id; size_t type_id;
RawData<helper::size> data; RawData<helper::size> data;
}; };
template <typename T> struct Vistor { typedef T type_t; }; template <typename T> struct Vistor { typedef T type_t; };
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -16,8 +16,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE ...@@ -16,8 +16,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. SOFTWARE.
==============================================================================*/ ==============================================================================*/
#include "attribute.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework {} namespace framework {}
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -22,110 +22,108 @@ SOFTWARE. ...@@ -22,110 +22,108 @@ SOFTWARE.
#include "framework.pb.h" #include "framework.pb.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
class BlockDesc; class BlockDesc;
class Attribute { class Attribute {
public: public:
static Attribute static Attribute GetAttrValue(const proto::OpDesc::Attr &attr_desc) {
GetAttrValue(const proto::OpDesc::Attr &attr_desc) { // std::cout << "begin get attr value" << std::endl;
// std::cout << "begin get attr value" << std::endl; Attribute attr;
Attribute attr; switch (attr_desc.type()) {
switch (attr_desc.type()) { case proto::AttrType::BOOLEAN: {
case proto::AttrType::BOOLEAN: { attr.Set<bool>(attr_desc.b());
attr.Set<bool>(attr_desc.b()); break;
break; }
} case proto::AttrType::INT: {
case proto::AttrType::INT: { attr.Set<int>(attr_desc.i());
attr.Set<int>(attr_desc.i()); break;
break; }
} case proto::AttrType::FLOAT: {
case proto::AttrType::FLOAT: { attr.Set<float>(attr_desc.f());
attr.Set<float>(attr_desc.f()); break;
break; }
} case proto::AttrType::STRING: {
case proto::AttrType::STRING: { attr.Set<std::string>(attr_desc.s());
attr.Set<std::string>(attr_desc.s()); break;
break; }
} case proto::AttrType::BOOLEANS: {
case proto::AttrType::BOOLEANS: { std::vector<bool> val(attr_desc.bools_size());
std::vector<bool> val(attr_desc.bools_size()); for (int i = 0; i < attr_desc.bools_size(); ++i) {
for (int i = 0; i < attr_desc.bools_size(); ++i) { val[i] = attr_desc.bools(i);
val[i] = attr_desc.bools(i);
}
attr.Set<std::vector<bool>>(val);
break;
}
case proto::AttrType::INTS: {
std::vector<int> val(attr_desc.ints_size());
for (int i = 0; i < attr_desc.ints_size(); ++i) {
val[i] = attr_desc.ints(i);
}
attr.Set<std::vector<int>>(val);
break;
}
case proto::AttrType::FLOATS: {
std::vector<float> val(attr_desc.floats_size());
for (int i = 0; i < attr_desc.floats_size(); ++i) {
val[i] = attr_desc.floats(i);
}
attr.Set<std::vector<float>>(val);
break;
}
case proto::AttrType::STRINGS: {
std::vector<std::string> val(attr_desc.strings_size());
for (int i = 0; i < attr_desc.strings_size(); ++i) {
val[i] = attr_desc.strings(i);
}
attr.Set<std::vector<std::string>>(val);
break;
}
case proto::AttrType::LONG: {
attr.Set<int64_t>(attr_desc.l());
break;
}
default:
// std::cout << " not support " << std::endl;
break;
}
// std::cout << "end get attr value" << std::endl;
return attr;
} }
attr.Set<std::vector<bool>>(val);
Attribute() {} break;
template <typename T, typename... Args> }
Attribute &Set(Args &&... args) { case proto::AttrType::INTS: {
variant_.Set<T>(args...); std::vector<int> val(attr_desc.ints_size());
return *this; for (int i = 0; i < attr_desc.ints_size(); ++i) {
val[i] = attr_desc.ints(i);
}
attr.Set<std::vector<int>>(val);
break;
}
case proto::AttrType::FLOATS: {
std::vector<float> val(attr_desc.floats_size());
for (int i = 0; i < attr_desc.floats_size(); ++i) {
val[i] = attr_desc.floats(i);
}
attr.Set<std::vector<float>>(val);
break;
}
case proto::AttrType::STRINGS: {
std::vector<std::string> val(attr_desc.strings_size());
for (int i = 0; i < attr_desc.strings_size(); ++i) {
val[i] = attr_desc.strings(i);
} }
attr.Set<std::vector<std::string>>(val);
break;
}
case proto::AttrType::LONG: {
attr.Set<int64_t>(attr_desc.l());
break;
}
default:
// std::cout << " not support " << std::endl;
break;
}
// std::cout << "end get attr value" << std::endl;
return attr;
}
template <typename T> T &Get() const { return variant_.Get<T>(); } Attribute() {}
template <typename T, typename... Args> Attribute &Set(Args &&... args) {
variant_.Set<T>(args...);
return *this;
}
private: template <typename T> T &Get() const { return variant_.Get<T>(); }
Variant<int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>, bool,
std::vector<bool>, BlockDesc *, int64_t>
variant_;
};
using AttributeMap = std::unordered_map<std::string, Attribute>; private:
Variant<int, float, std::string, std::vector<int>, std::vector<float>,
std::vector<std::string>, bool, std::vector<bool>, BlockDesc *,
int64_t>
variant_;
};
class AttrReader { using AttributeMap = std::unordered_map<std::string, Attribute>;
public:
explicit AttrReader(const AttributeMap &attrs) : attrs_(attrs) {}
template <typename T> inline T Get(const std::string &name) const { class AttrReader {
// PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should public:
// be in explicit AttrReader(const AttributeMap &attrs) : attrs_(attrs) {}
// AttributeMap",
// name); template <typename T> inline T Get(const std::string &name) const {
return ((Attribute)attrs_.at(name)).Get<T>(); // PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should
} // be in
// AttributeMap",
// name);
return ((Attribute)attrs_.at(name)).Get<T>();
}
private: private:
const AttributeMap &attrs_; const AttributeMap &attrs_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -19,32 +19,32 @@ SOFTWARE. ...@@ -19,32 +19,32 @@ SOFTWARE.
#include "block_desc.h" #include "block_desc.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
std::vector<std::shared_ptr<VarDesc>> BlockDesc::Vars() const { std::vector<std::shared_ptr<VarDesc>> BlockDesc::Vars() const {
std::vector<std::shared_ptr<VarDesc>> res; std::vector<std::shared_ptr<VarDesc>> res;
for (const auto &p : vars_) { for (const auto &p : vars_) {
res.push_back(p.second); res.push_back(p.second);
} }
return res; return res;
} }
std::vector<std::shared_ptr<OpDesc>> BlockDesc::Ops() const { std::vector<std::shared_ptr<OpDesc>> BlockDesc::Ops() const {
std::vector<std::shared_ptr<OpDesc>> res; std::vector<std::shared_ptr<OpDesc>> res;
for (const auto &op : ops_) { for (const auto &op : ops_) {
res.push_back(op); res.push_back(op);
} }
return res; return res;
} }
BlockDesc::BlockDesc(const proto::BlockDesc &desc) : desc_(desc) { BlockDesc::BlockDesc(const proto::BlockDesc &desc) : desc_(desc) {
for (const proto::VarDesc &var_desc : desc_.vars()) { for (const proto::VarDesc &var_desc : desc_.vars()) {
vars_[var_desc.name()].reset(new VarDesc(var_desc)); vars_[var_desc.name()].reset(new VarDesc(var_desc));
} }
for (const proto::OpDesc &op_desc : desc_.ops()) { for (const proto::OpDesc &op_desc : desc_.ops()) {
ops_.emplace_back(new framework::OpDesc(op_desc)); ops_.emplace_back(new framework::OpDesc(op_desc));
} }
} }
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -24,50 +24,47 @@ SOFTWARE. ...@@ -24,50 +24,47 @@ SOFTWARE.
#include "var_desc.h" #include "var_desc.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
class BlockDesc : PaddleMobileObject { class BlockDesc : PaddleMobileObject {
public: public:
BlockDesc(const proto::BlockDesc &desc); BlockDesc(const proto::BlockDesc &desc);
const int &ID() const { return desc_.idx(); } const int &ID() const { return desc_.idx(); }
const int &Parent() const { return desc_.parent_idx(); } const int &Parent() const { return desc_.parent_idx(); }
bool operator==( bool operator==(const paddle_mobile::framework::BlockDesc &in_block) const {
const paddle_mobile::framework::BlockDesc &in_block) const { return this->ID() == in_block.ID() &&
return this->ID() == in_block.ID() && this->Parent() == in_block.Parent();
this->Parent() == in_block.Parent(); }
}
bool operator<( bool operator<(const paddle_mobile::framework::BlockDesc &in_block) const {
const paddle_mobile::framework::BlockDesc &in_block) const { return this->ID() < in_block.ID() && this->Parent() < in_block.Parent();
return this->ID() < in_block.ID() && }
this->Parent() < in_block.Parent();
}
std::vector<std::shared_ptr<VarDesc>> Vars() const; std::vector<std::shared_ptr<VarDesc>> Vars() const;
std::vector<std::shared_ptr<OpDesc>> Ops() const; std::vector<std::shared_ptr<OpDesc>> Ops() const;
private: private:
proto::BlockDesc desc_; proto::BlockDesc desc_;
std::vector<std::shared_ptr<OpDesc>> ops_; std::vector<std::shared_ptr<OpDesc>> ops_;
std::unordered_map<std::string, std::shared_ptr<VarDesc>> vars_; std::unordered_map<std::string, std::shared_ptr<VarDesc>> vars_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
namespace std { namespace std {
template <> struct hash<paddle_mobile::framework::BlockDesc> { template <> struct hash<paddle_mobile::framework::BlockDesc> {
typedef paddle_mobile::framework::BlockDesc argument_type; typedef paddle_mobile::framework::BlockDesc argument_type;
typedef std::size_t result_type; typedef std::size_t result_type;
result_type operator()(argument_type const &s) const noexcept { result_type operator()(argument_type const &s) const noexcept {
result_type const h1(std::hash<int>{}(s.ID())); result_type const h1(std::hash<int>{}(s.ID()));
result_type const h2(std::hash<int>{}(s.ID())); result_type const h2(std::hash<int>{}(s.ID()));
return h1 ^ (h2 << 1); return h1 ^ (h2 << 1);
} }
}; };
} // namespace std } // namespace std
...@@ -19,50 +19,49 @@ limitations under the License. */ ...@@ -19,50 +19,49 @@ limitations under the License. */
#include <string> #include <string>
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
enum class DataLayout { enum class DataLayout {
kNHWC = 0, kNHWC = 0,
kNCHW = 1, kNCHW = 1,
kAnyLayout = 2, kAnyLayout = 2,
}; };
inline DataLayout StringToDataLayout(const std::string &str) { inline DataLayout StringToDataLayout(const std::string &str) {
std::string s(str); std::string s(str);
for (size_t i = 0; i < s.size(); ++i) { for (size_t i = 0; i < s.size(); ++i) {
s[i] = toupper(s[i]); s[i] = toupper(s[i]);
} }
if (s == "NHWC") { if (s == "NHWC") {
return DataLayout::kNHWC; return DataLayout::kNHWC;
} else if (s == "NCHW") { } else if (s == "NCHW") {
return DataLayout::kNCHW; return DataLayout::kNCHW;
} else if (s == "ANYLAYOUT") { } else if (s == "ANYLAYOUT") {
return DataLayout::kAnyLayout; return DataLayout::kAnyLayout;
} else { } else {
// std::cout << "Unknown storage order string: %s", s; // std::cout << "Unknown storage order string: %s", s;
} }
} }
inline std::string DataLayoutToString(const DataLayout &data_layout) { inline std::string DataLayoutToString(const DataLayout &data_layout) {
switch (data_layout) { switch (data_layout) {
case DataLayout::kNHWC: case DataLayout::kNHWC:
return "NHWC"; return "NHWC";
case DataLayout::kNCHW: case DataLayout::kNCHW:
return "NCHW"; return "NCHW";
case DataLayout::kAnyLayout: case DataLayout::kAnyLayout:
return "ANY_LAYOUT"; return "ANY_LAYOUT";
default: default:
break; break;
// std::cout << "unknown DataLayou %d", data_layout; // std::cout << "unknown DataLayou %d", data_layout;
} }
} }
inline std::ostream &operator<<(std::ostream &out, inline std::ostream &operator<<(std::ostream &out, const DataLayout &l) {
const DataLayout &l) { out << DataLayoutToString(l);
out << DataLayoutToString(l); return out;
return out; }
}
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -21,72 +21,72 @@ SOFTWARE. ...@@ -21,72 +21,72 @@ SOFTWARE.
#include "data_transform.h" #include "data_transform.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
static void PassTensorData(Tensor *from, Tensor *to) { static void PassTensorData(Tensor *from, Tensor *to) {
to->ShareDataWith(*from); to->ShareDataWith(*from);
*from = Tensor(); *from = Tensor();
} }
void DataTransform(const OpKernelType &expected_kernel_type, void DataTransform(const OpKernelType &expected_kernel_type,
const OpKernelType &kernel_type_for_var, const OpKernelType &kernel_type_for_var,
const Tensor &input_tensor, Tensor *output_tensor) { const Tensor &input_tensor, Tensor *output_tensor) {
bool transformed = false; bool transformed = false;
Tensor in; Tensor in;
in.ShareDataWith(input_tensor); in.ShareDataWith(input_tensor);
Tensor out; Tensor out;
// // do layout transform // // do layout transform
// if (NeedTransformLayout(expected_kernel_type.data_layout_, // if (NeedTransformLayout(expected_kernel_type.data_layout_,
// kernel_type_for_var.data_layout_)) { // kernel_type_for_var.data_layout_)) {
// TransDataLayout(kernel_type_for_var, expected_kernel_type, in, // TransDataLayout(kernel_type_for_var, expected_kernel_type, in,
// &out); // &out);
// transformed = true; // transformed = true;
// PassTensorData(&out, &in); // PassTensorData(&out, &in);
// } // }
// //
// // do data type transform // // do data type transform
// if (expected_kernel_type.data_type_ != // if (expected_kernel_type.data_type_ !=
// kernel_type_for_var.data_type_) { // kernel_type_for_var.data_type_) {
// TransDataType(kernel_type_for_var, expected_kernel_type, in, // TransDataType(kernel_type_for_var, expected_kernel_type, in,
// &out); // &out);
// transformed = true; // transformed = true;
// PassTensorData(&out, &in); // PassTensorData(&out, &in);
// } // }
// //
// // do device transform // // do device transform
// if (!platform::is_same_place(kernel_type_for_var.place_, // if (!platform::is_same_place(kernel_type_for_var.place_,
// expected_kernel_type.place_)) { // expected_kernel_type.place_)) {
// TransDataDevice(in, expected_kernel_type.place_, &out); // TransDataDevice(in, expected_kernel_type.place_, &out);
// transformed = true; // transformed = true;
// PassTensorData(&out, &in); // PassTensorData(&out, &in);
// } // }
// //
// PADDLE_ENFORCE(transformed, "No transform is applied, please // PADDLE_ENFORCE(transformed, "No transform is applied, please
// check!"); // check!");
// get output data // get output data
output_tensor->ShareDataWith(in); output_tensor->ShareDataWith(in);
} }
void CopyVariableWithTensor(const Variable &in_var, void CopyVariableWithTensor(const Variable &in_var, const Tensor &tensor,
const Tensor &tensor, Variable &out_var) { Variable &out_var) {
// if (in_var.IsType<LoDTensor>()) { // if (in_var.IsType<LoDTensor>()) {
// auto& in_lod_tensor = in_var.Get<LoDTensor>(); // auto& in_lod_tensor = in_var.Get<LoDTensor>();
// auto* tran_lod_tensor = out_var.GetMutable<LoDTensor>(); // auto* tran_lod_tensor = out_var.GetMutable<LoDTensor>();
// tran_lod_tensor->set_lod(in_lod_tensor.lod()); // tran_lod_tensor->set_lod(in_lod_tensor.lod());
// tran_lod_tensor->set_layout(in_lod_tensor.layout()); // tran_lod_tensor->set_layout(in_lod_tensor.layout());
// tran_lod_tensor->ShareDataWith(tensor); // tran_lod_tensor->ShareDataWith(tensor);
// } else if (in_var.IsType<SelectedRows>()) { // } else if (in_var.IsType<SelectedRows>()) {
// auto& in_selected_rows = in_var.Get<SelectedRows>(); // auto& in_selected_rows = in_var.Get<SelectedRows>();
// auto* trans_selected_rows = // auto* trans_selected_rows =
// out_var.GetMutable<SelectedRows>(); // out_var.GetMutable<SelectedRows>();
// trans_selected_rows->set_height(in_selected_rows.height()); // trans_selected_rows->set_height(in_selected_rows.height());
// trans_selected_rows->set_rows(in_selected_rows.rows()); // trans_selected_rows->set_rows(in_selected_rows.rows());
// trans_selected_rows->mutable_value()->ShareDataWith(tensor); // trans_selected_rows->mutable_value()->ShareDataWith(tensor);
// } else { // } else {
// PADDLE_THROW("unknown var type"); // PADDLE_THROW("unknown var type");
// } // }
} }
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -28,14 +28,14 @@ SOFTWARE. ...@@ -28,14 +28,14 @@ SOFTWARE.
#include "variable.h" #include "variable.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
void DataTransform(const OpKernelType &expected_kernel_type, void DataTransform(const OpKernelType &expected_kernel_type,
const OpKernelType &kernel_type_for_var, const OpKernelType &kernel_type_for_var,
const Tensor &input_tensor, Tensor *out); const Tensor &input_tensor, Tensor *out);
void CopyVariableWithTensor(const Variable &in_var, void CopyVariableWithTensor(const Variable &in_var, const Tensor &tensor,
const Tensor &tensor, Variable &out_var); Variable &out_var);
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -21,23 +21,23 @@ SOFTWARE. ...@@ -21,23 +21,23 @@ SOFTWARE.
#include "framework.pb.h" #include "framework.pb.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
// inline proto::VarType::Type ToDataType(std::type_index type) { // inline proto::VarType::Type ToDataType(std::type_index type) {
// using namespace paddle_mobile::framework::proto; // using namespace paddle_mobile::framework::proto;
// if (typeid(float).hash_code() == type.hash_code()) { // if (typeid(float).hash_code() == type.hash_code()) {
// return proto::VarType::FP32; // return proto::VarType::FP32;
// } else if (typeid(double).hash_code() == type.hash_code()) { // } else if (typeid(double).hash_code() == type.hash_code()) {
// return proto::VarType::FP64; // return proto::VarType::FP64;
// } else if (typeid(int).hash_code() == type.hash_code()) { // } else if (typeid(int).hash_code() == type.hash_code()) {
// return proto::VarType::INT32; // return proto::VarType::INT32;
// } else if (typeid(int64_t).hash_code() == type.hash_code()) { // } else if (typeid(int64_t).hash_code() == type.hash_code()) {
// return proto::VarType::INT64; // return proto::VarType::INT64;
// } else if (typeid(bool).hash_code() == type.hash_code()) { // } else if (typeid(bool).hash_code() == type.hash_code()) {
// return proto::VarType::BOOL; // return proto::VarType::BOOL;
// } else { // } else {
//// PADDLE_THROW("Not supported"); //// PADDLE_THROW("Not supported");
// } // }
// } // }
} }
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -15,320 +15,318 @@ limitations under the License. */ ...@@ -15,320 +15,318 @@ limitations under the License. */
#include "ddim.h" #include "ddim.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
/// @cond HIDDEN /// @cond HIDDEN
template <int i> Dim<i> make_dim(const int64_t *d) { template <int i> Dim<i> make_dim(const int64_t *d) {
return Dim<i>(*d, make_dim<i - 1>(d + 1)); return Dim<i>(*d, make_dim<i - 1>(d + 1));
} }
template <> Dim<0> make_dim<0>(const int64_t *d) { return Dim<0>(*d); } template <> Dim<0> make_dim<0>(const int64_t *d) { return Dim<0>(*d); }
void make_ddim(DDim &ddim, const int64_t *dims, int n) { void make_ddim(DDim &ddim, const int64_t *dims, int n) {
switch (n) { switch (n) {
case 0: case 0:
ddim = make_dim<0>(dims); ddim = make_dim<0>(dims);
break; break;
case 1: case 1:
ddim = make_dim<1>(dims); ddim = make_dim<1>(dims);
break; break;
case 2: case 2:
ddim = make_dim<2>(dims); ddim = make_dim<2>(dims);
break; break;
case 3: case 3:
ddim = make_dim<3>(dims); ddim = make_dim<3>(dims);
break; break;
case 4: case 4:
ddim = make_dim<4>(dims); ddim = make_dim<4>(dims);
break; break;
case 5: case 5:
ddim = make_dim<5>(dims); ddim = make_dim<5>(dims);
break; break;
case 6: case 6:
ddim = make_dim<6>(dims); ddim = make_dim<6>(dims);
break; break;
case 7: case 7:
ddim = make_dim<7>(dims); ddim = make_dim<7>(dims);
break; break;
case 8: case 8:
ddim = make_dim<8>(dims); ddim = make_dim<8>(dims);
break; break;
case 9: case 9:
ddim = make_dim<9>(dims); ddim = make_dim<9>(dims);
break; break;
default: default:
// std::cout << "Dynamic dimensions must have between [1, // std::cout << "Dynamic dimensions must have between [1,
// 9] // 9]
// dimensions."; // dimensions.";
break; break;
} }
} }
/// @endcond /// @endcond
DDim make_ddim(std::initializer_list<int64_t> dims) { DDim make_ddim(std::initializer_list<int64_t> dims) {
DDim result(make_dim(0)); DDim result(make_dim(0));
make_ddim(result, dims.begin(), dims.size()); make_ddim(result, dims.begin(), dims.size());
return result; return result;
} }
DDim make_ddim(const std::vector<int64_t> &dims) { DDim make_ddim(const std::vector<int64_t> &dims) {
DDim result(make_dim(0)); DDim result(make_dim(0));
make_ddim(result, &dims[0], dims.size()); make_ddim(result, &dims[0], dims.size());
return result; return result;
} }
DDim make_ddim(const std::vector<int> &dims) { DDim make_ddim(const std::vector<int> &dims) {
std::vector<int64_t> res(dims.size()); std::vector<int64_t> res(dims.size());
std::transform(dims.begin(), dims.end(), res.begin(), std::transform(dims.begin(), dims.end(), res.begin(),
[](int d) { return static_cast<int64_t>(d); }); [](int d) { return static_cast<int64_t>(d); });
return make_ddim(res); return make_ddim(res);
} }
/// @cond HIDDEN /// @cond HIDDEN
// XXX For some reason, putting this in an anonymous namespace causes // XXX For some reason, putting this in an anonymous namespace causes
// errors // errors
struct DynamicMutableIndexer : Vistor<int64_t &> { struct DynamicMutableIndexer : Vistor<int64_t &> {
public: public:
explicit DynamicMutableIndexer(int idx) : idx_(idx) {} explicit DynamicMutableIndexer(int idx) : idx_(idx) {}
template <int D> int64_t &operator()(Dim<D> &dim) const { template <int D> int64_t &operator()(Dim<D> &dim) const {
return dim[idx_]; return dim[idx_];
} }
private: private:
int idx_; int idx_;
}; };
struct DynamicConstIndexer : public Vistor<int64_t> { struct DynamicConstIndexer : public Vistor<int64_t> {
public: public:
explicit DynamicConstIndexer(int idx) : idx_(idx) {} explicit DynamicConstIndexer(int idx) : idx_(idx) {}
template <int D> int64_t operator()(const Dim<D> &dim) const { template <int D> int64_t operator()(const Dim<D> &dim) const {
return dim[idx_]; return dim[idx_];
} }
private: private:
int idx_; int idx_;
}; };
/// @endcond /// @endcond
int64_t &DDim::operator[](int idx) { int64_t &DDim::operator[](int idx) {
return DDim::ApplyVistor(DynamicMutableIndexer(idx), *this); return DDim::ApplyVistor(DynamicMutableIndexer(idx), *this);
} }
int64_t DDim::operator[](int idx) const { int64_t DDim::operator[](int idx) const {
return DDim::ApplyVistor(DynamicConstIndexer(idx), *this); return DDim::ApplyVistor(DynamicConstIndexer(idx), *this);
} }
int DDim::size() const { return arity(*this); } int DDim::size() const { return arity(*this); }
bool DDim::operator==(DDim d) const { bool DDim::operator==(DDim d) const {
// if (var.which() != d.getVar().which()) { // if (var.which() != d.getVar().which()) {
// return false; // return false;
// } else { // } else {
std::vector<int64_t> v1 = vectorize(*this); std::vector<int64_t> v1 = vectorize(*this);
std::vector<int64_t> v2 = vectorize(d); std::vector<int64_t> v2 = vectorize(d);
for (unsigned int i = 0; i < v1.size(); i++) { for (unsigned int i = 0; i < v1.size(); i++) {
if (v1[i] != v2[i]) { if (v1[i] != v2[i]) {
return false; return false;
}
}
return true;
// }
}
bool DDim::operator!=(DDim d) const { return !(*this == d); }
DDim DDim::operator+(DDim d) const {
std::vector<int64_t> v1 = vectorize(*this);
std::vector<int64_t> v2 = vectorize(d);
std::vector<int64_t> v3;
assert(v1.size() == v2.size());
for (unsigned int i = 0; i < v1.size(); i++) {
v3.push_back(v1[i] + v2[i]);
}
return make_ddim(v3);
}
DDim DDim::operator*(DDim d) const {
std::vector<int64_t> v1 = vectorize(*this);
std::vector<int64_t> v2 = vectorize(d);
std::vector<int64_t> v3;
assert(v1.size() == v2.size());
for (unsigned int i = 0; i < v1.size(); i++) {
v3.push_back(v1[i] * v2[i]);
}
return make_ddim(v3);
} }
}
int64_t get(const DDim &ddim, int idx) { return ddim[idx]; } return true;
// }
void set(DDim &ddim, int idx, int value) { ddim[idx] = value; } }
/// @cond HIDDEN bool DDim::operator!=(DDim d) const { return !(*this == d); }
struct VectorizeVisitor : Vistor<void> {
std::vector<int64_t> &vector; DDim DDim::operator+(DDim d) const {
std::vector<int64_t> v1 = vectorize(*this);
explicit VectorizeVisitor(std::vector<int64_t> &v) : vector(v) {} std::vector<int64_t> v2 = vectorize(d);
template <typename T> void operator()(const T &t) { std::vector<int64_t> v3;
vector.push_back(t.head);
this->operator()(t.tail); assert(v1.size() == v2.size());
}
for (unsigned int i = 0; i < v1.size(); i++) {
void operator()(const Dim<0> &t) {} v3.push_back(v1[i] + v2[i]);
}; }
/// @endcond
return make_ddim(v3);
std::vector<int64_t> vectorize(const DDim &ddim) { }
std::vector<int64_t> result;
VectorizeVisitor visitor(result); DDim DDim::operator*(DDim d) const {
DDim::ApplyVistor(visitor, ddim); std::vector<int64_t> v1 = vectorize(*this);
return result; std::vector<int64_t> v2 = vectorize(d);
std::vector<int64_t> v3;
assert(v1.size() == v2.size());
for (unsigned int i = 0; i < v1.size(); i++) {
v3.push_back(v1[i] * v2[i]);
}
return make_ddim(v3);
}
int64_t get(const DDim &ddim, int idx) { return ddim[idx]; }
void set(DDim &ddim, int idx, int value) { ddim[idx] = value; }
/// @cond HIDDEN
struct VectorizeVisitor : Vistor<void> {
std::vector<int64_t> &vector;
explicit VectorizeVisitor(std::vector<int64_t> &v) : vector(v) {}
template <typename T> void operator()(const T &t) {
vector.push_back(t.head);
this->operator()(t.tail);
}
void operator()(const Dim<0> &t) {}
};
/// @endcond
std::vector<int64_t> vectorize(const DDim &ddim) {
std::vector<int64_t> result;
VectorizeVisitor visitor(result);
DDim::ApplyVistor(visitor, ddim);
return result;
}
// NOTE: framework::vectorize converts to type int64_t
// which does not fit cudnn inputs.
std::vector<int> vectorize2int(const DDim &ddim) {
std::vector<int64_t> temp = vectorize(ddim);
std::vector<int> result(temp.begin(), temp.end());
return result;
}
struct ProductVisitor : Vistor<int64_t> {
template <int D> int64_t operator()(const Dim<D> &dim) {
return product(dim);
}
};
int64_t product(const DDim &ddim) {
ProductVisitor visitor;
return DDim::ApplyVistor(visitor, ddim);
}
struct SliceVectorizeVisitor : Vistor<void> {
std::vector<int64_t> &vector;
int begin;
int end;
SliceVectorizeVisitor(std::vector<int64_t> &v, int b, int e)
: vector(v), begin(b), end(e) {
// PADDLE_ENFORCE(begin < end,
// "Begin index must be less than end index in
// ddim
// slice.");
// PADDLE_ENFORCE(begin >= 0,
// "Begin index can't be less than zero in
// ddim slice.");
}
template <int S> void operator()(const Dim<S> &dim) {
if (begin == 0) {
vector.push_back(dim.head);
} else {
--begin;
} }
--end;
// NOTE: framework::vectorize converts to type int64_t if (end > 0) {
// which does not fit cudnn inputs. this->operator()(dim.tail);
std::vector<int> vectorize2int(const DDim &ddim) {
std::vector<int64_t> temp = vectorize(ddim);
std::vector<int> result(temp.begin(), temp.end());
return result;
} }
}
struct ProductVisitor : Vistor<int64_t> {
template <int D> int64_t operator()(const Dim<D> &dim) { void operator()(const Dim<0> &dim) {
return product(dim); // PADDLE_ENFORCE(end == 0, "End index in ddim slice is out
} // of bound.");
}; }
};
int64_t product(const DDim &ddim) {
ProductVisitor visitor; DDim slice_ddim(const DDim &ddim, int begin, int end) {
return DDim::ApplyVistor(visitor, ddim); std::vector<int64_t> vec;
} vec.reserve(end - begin);
SliceVectorizeVisitor visitor(vec, begin, end);
struct SliceVectorizeVisitor : Vistor<void> { // boost::apply_visitor(visitor, dim);
std::vector<int64_t> &vector; DDim::ApplyVistor(visitor, ddim);
int begin; // visitor(ddim.var.Get<Dim<4>>());
int end; return make_ddim(vec);
}
SliceVectorizeVisitor(std::vector<int64_t> &v, int b, int e)
: vector(v), begin(b), end(e) { /// \cond HIDDEN
// PADDLE_ENFORCE(begin < end,
// "Begin index must be less than end index in struct ArityVisitor : Vistor<int> {
// ddim template <int D> int operator()(Dim<D>) const { return D; }
// slice."); };
// PADDLE_ENFORCE(begin >= 0,
// "Begin index can't be less than zero in /// \endcond
// ddim slice.");
} int arity(const DDim &d) {
ArityVisitor arityVisitor = ArityVisitor();
template <int S> void operator()(const Dim<S> &dim) { return DDim::ApplyVistor(arityVisitor, d);
if (begin == 0) { // return arityVisitor(d.var.Get<Dim<4>>());
vector.push_back(dim.head); // return boost::apply_visitor(ArityVisitor(), d); }
} else { }
--begin; /// \cond HIDDEN
}
--end; /// \endcond
if (end > 0) {
this->operator()(dim.tail); struct OSVistor : Vistor<std::ostream &> {
} OSVistor(std::ostream &os) : os_(os) {}
}
template <int D> std::ostream &operator()(Dim<D> dim) const {
void operator()(const Dim<0> &dim) { return os_ << dim;
// PADDLE_ENFORCE(end == 0, "End index in ddim slice is out }
// of bound.");
} private:
}; std::ostream &os_;
};
DDim slice_ddim(const DDim &ddim, int begin, int end) {
std::vector<int64_t> vec; std::ostream &operator<<(std::ostream &os, const DDim &ddim) {
vec.reserve(end - begin); auto vistor = OSVistor(os);
SliceVectorizeVisitor visitor(vec, begin, end); DDim::ApplyVistor(vistor, ddim);
// boost::apply_visitor(visitor, dim); return os;
DDim::ApplyVistor(visitor, ddim); }
// visitor(ddim.var.Get<Dim<4>>());
return make_ddim(vec); DDim::DDim(std::initializer_list<int64_t> init_list) {
} *this = make_ddim(init_list);
}
/// \cond HIDDEN
DDim flatten_to_2d(const DDim &src, int num_col_dims) {
struct ArityVisitor : Vistor<int> { int rank = src.size();
template <int D> int operator()(Dim<D>) const { return D; } return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
}; product(slice_ddim(src, num_col_dims, rank))});
}
/// \endcond
DDim flatten_to_1d(const DDim &src) { return make_ddim({product(src)}); }
int arity(const DDim &d) {
ArityVisitor arityVisitor = ArityVisitor(); DDim stride(const DDim &ddim) {
return DDim::ApplyVistor(arityVisitor, d); std::vector<int64_t> strides(ddim.size());
// return arityVisitor(d.var.Get<Dim<4>>()); strides[ddim.size() - 1] = 1;
// return boost::apply_visitor(ArityVisitor(), d); } for (int i = ddim.size() - 2; i >= 0; --i) {
} strides[i] = strides[i + 1] * ddim[i + 1];
/// \cond HIDDEN }
return framework::make_ddim(strides);
/// \endcond }
struct OSVistor : Vistor<std::ostream &> { DDim stride_numel(const framework::DDim &ddim) {
OSVistor(std::ostream &os) : os_(os) {} std::vector<int64_t> strides(ddim.size());
strides[ddim.size() - 1] = ddim[ddim.size() - 1];
template <int D> std::ostream &operator()(Dim<D> dim) const { for (int i = ddim.size() - 2; i >= 0; --i) {
return os_ << dim; strides[i] = strides[i + 1] * ddim[i];
} }
return framework::make_ddim(strides);
private: }
std::ostream &os_;
}; } // namespace framework
std::ostream &operator<<(std::ostream &os, const DDim &ddim) {
auto vistor = OSVistor(os);
DDim::ApplyVistor(vistor, ddim);
return os;
}
DDim::DDim(std::initializer_list<int64_t> init_list) {
*this = make_ddim(init_list);
}
DDim flatten_to_2d(const DDim &src, int num_col_dims) {
int rank = src.size();
return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
product(slice_ddim(src, num_col_dims, rank))});
}
DDim flatten_to_1d(const DDim &src) {
return make_ddim({product(src)});
}
DDim stride(const DDim &ddim) {
std::vector<int64_t> strides(ddim.size());
strides[ddim.size() - 1] = 1;
for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i + 1];
}
return framework::make_ddim(strides);
}
DDim stride_numel(const framework::DDim &ddim) {
std::vector<int64_t> strides(ddim.size());
strides[ddim.size() - 1] = ddim[ddim.size() - 1];
for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i];
}
return framework::make_ddim(strides);
}
} // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -22,145 +22,142 @@ limitations under the License. */ ...@@ -22,145 +22,142 @@ limitations under the License. */
#include <vector> #include <vector>
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
/** /**
* \brief A dynamically sized dimension. * \brief A dynamically sized dimension.
* *
* The number of dimensions must be between [1, 9]. * The number of dimensions must be between [1, 9].
*/ */
struct DDim { struct DDim {
typedef Variant<Dim<0>, Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, typedef Variant<Dim<0>, Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>,
Dim<6>, Dim<7>, Dim<8>, Dim<9>> Dim<7>, Dim<8>, Dim<9>>
DDimVar; DDimVar;
DDimVar var; DDimVar var;
template <typename Vistor> template <typename Vistor>
static typename Vistor::type_t ApplyVistor(Vistor vistor, static typename Vistor::type_t ApplyVistor(Vistor vistor, const DDim &d) {
const DDim &d) { if (d.var.TypeId() == typeid(Dim<0>).hash_code()) {
if (d.var.TypeId() == typeid(Dim<0>).hash_code()) { return vistor(d.var.Get<Dim<0>>());
return vistor(d.var.Get<Dim<0>>()); } else if (d.var.TypeId() == typeid(Dim<1>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<1>).hash_code()) { return vistor(d.var.Get<Dim<1>>());
return vistor(d.var.Get<Dim<1>>()); } else if (d.var.TypeId() == typeid(Dim<2>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<2>).hash_code()) { return vistor(d.var.Get<Dim<2>>());
return vistor(d.var.Get<Dim<2>>()); } else if (d.var.TypeId() == typeid(Dim<3>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<3>).hash_code()) { return vistor(d.var.Get<Dim<3>>());
return vistor(d.var.Get<Dim<3>>()); } else if (d.var.TypeId() == typeid(Dim<4>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<4>).hash_code()) { return vistor(d.var.Get<Dim<4>>());
return vistor(d.var.Get<Dim<4>>()); } else if (d.var.TypeId() == typeid(Dim<5>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<5>).hash_code()) { return vistor(d.var.Get<Dim<5>>());
return vistor(d.var.Get<Dim<5>>()); } else if (d.var.TypeId() == typeid(Dim<6>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<6>).hash_code()) { return vistor(d.var.Get<Dim<6>>());
return vistor(d.var.Get<Dim<6>>()); } else if (d.var.TypeId() == typeid(Dim<7>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<7>).hash_code()) { return vistor(d.var.Get<Dim<7>>());
return vistor(d.var.Get<Dim<7>>()); } else if (d.var.TypeId() == typeid(Dim<8>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<8>).hash_code()) { return vistor(d.var.Get<Dim<8>>());
return vistor(d.var.Get<Dim<8>>()); } else if (d.var.TypeId() == typeid(Dim<9>).hash_code()) {
} else if (d.var.TypeId() == typeid(Dim<9>).hash_code()) { return vistor(d.var.Get<Dim<9>>());
return vistor(d.var.Get<Dim<9>>()); } else {
} else { printf(" dim not support \n");
printf(" dim not support \n"); throw std::bad_exception();
throw std::bad_exception(); // return typename Vistor::type_t();
// return typename Vistor::type_t(); }
} }
}
DDim() { var.Set<Dim<1>>(Dim<1>()); }
DDim() { var.Set<Dim<1>>(Dim<1>()); }
template <int D> explicit DDim(const Dim<D> &in) { var.Set<Dim<D>>(in); }
template <int D> explicit DDim(const Dim<D> &in) {
var.Set<Dim<D>>(in); /*implicit*/ DDim(std::initializer_list<int64_t> init_list);
}
template <int D> DDim &operator=(const Dim<D> &in) {
/*implicit*/ DDim(std::initializer_list<int64_t> init_list); var.Set<Dim<D>>(in);
return *this;
template <int D> DDim &operator=(const Dim<D> &in) { }
var.Set<Dim<D>>(in);
return *this; int64_t &operator[](int idx);
}
int64_t operator[](int idx) const;
int64_t &operator[](int idx);
// template <typename Visitor>
int64_t operator[](int idx) const; // typename Visitor::result_type apply_visitor(Visitor& visitor) {
// return var.apply_visitor(visitor);
// template <typename Visitor> // }
// typename Visitor::result_type apply_visitor(Visitor& visitor) { //
// return var.apply_visitor(visitor); // template <typename Visitor>
// } // typename Visitor::result_type apply_visitor(Visitor& visitor)
// // const {
// template <typename Visitor> // return var.apply_visitor(visitor);
// typename Visitor::result_type apply_visitor(Visitor& visitor) // }
// const {
// return var.apply_visitor(visitor); DDimVar getVar() { return var; }
// }
bool operator==(DDim d) const;
DDimVar getVar() { return var; }
bool operator!=(DDim d) const;
bool operator==(DDim d) const;
DDim operator+(DDim d) const;
bool operator!=(DDim d) const;
DDim operator*(DDim d) const;
DDim operator+(DDim d) const;
int size() const;
DDim operator*(DDim d) const; };
int size() const; /**
}; * \brief Make a DDim from std::vector<int64_t>
*
/** * \param dims An vector of ints. Must be sized between [1, 9]
* \brief Make a DDim from std::vector<int64_t> */
* DDim make_ddim(const std::vector<int64_t> &dims);
* \param dims An vector of ints. Must be sized between [1, 9]
*/ DDim make_ddim(const std::vector<int> &dims);
DDim make_ddim(const std::vector<int64_t> &dims);
/**
DDim make_ddim(const std::vector<int> &dims); * \brief Make a DDim from an initializer list
*
/** * \param dims An initializer list of ints. Must be sized between [1, 9]
* \brief Make a DDim from an initializer list *
* */
* \param dims An initializer list of ints. Must be sized between [1, 9] DDim make_ddim(std::initializer_list<int64_t> dims);
*
*/ int64_t get(const DDim &dim, int idx);
DDim make_ddim(std::initializer_list<int64_t> dims);
void set(DDim &dim, int idx, int val);
int64_t get(const DDim &dim, int idx);
std::vector<int64_t> vectorize(const DDim &ddim);
void set(DDim &dim, int idx, int val);
std::vector<int> vectorize2int(const DDim &ddim);
std::vector<int64_t> vectorize(const DDim &ddim);
int64_t product(const DDim &ddim);
std::vector<int> vectorize2int(const DDim &ddim);
/**
int64_t product(const DDim &ddim); * \brief Slice a ddim
*
/** * Slice dim with [begin, end).
* \brief Slice a ddim * e.g. DDim d = make_ddim({1,2,3,4,5});
* * slice_ddim(d, 1, 3); ====> {2,3}
* Slice dim with [begin, end). */
* e.g. DDim d = make_ddim({1,2,3,4,5}); DDim slice_ddim(const DDim &dim, int begin, int end);
* slice_ddim(d, 1, 3); ====> {2,3}
*/ /**
DDim slice_ddim(const DDim &dim, int begin, int end); * \brief What is the length of this dimension?
*
/** * \param Dynamic dimension to inspect
* \brief What is the length of this dimension? */
*
* \param Dynamic dimension to inspect
*/
int arity(const DDim &ddim); int arity(const DDim &ddim);
std::ostream &operator<<(std::ostream &, const DDim &); std::ostream &operator<<(std::ostream &, const DDim &);
// Reshape a tensor to a matrix. The matrix's first dimension(column // Reshape a tensor to a matrix. The matrix's first dimension(column
// length) // length)
// will be the product of tensor's first `num_col_dims` dimensions. // will be the product of tensor's first `num_col_dims` dimensions.
DDim flatten_to_2d(const DDim &src, int num_col_dims); DDim flatten_to_2d(const DDim &src, int num_col_dims);
DDim flatten_to_1d(const DDim &src); DDim flatten_to_1d(const DDim &src);
DDim stride(const DDim &ddim); DDim stride(const DDim &ddim);
DDim stride_numel(const DDim &ddim); DDim stride_numel(const DDim &ddim);
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
此差异已折叠。
...@@ -20,78 +20,74 @@ SOFTWARE. ...@@ -20,78 +20,74 @@ SOFTWARE.
#include "executor.h" #include "executor.h"
#include "lod_tensor.h" #include "lod_tensor.h"
#include "operators/conv_op.h" #include "operators/conv_op.h"
#include "variable.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
template <typename Dtype> template <typename Dtype>
Executor<Dtype>::Executor(const Program<Dtype> p) : program_(p) { Executor<Dtype>::Executor(const Program<Dtype> p) : program_(p) {
if (use_optimize_) { if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram; to_predict_program_ = program_.optimizeProgram;
} else { } else {
to_predict_program_ = program_.originProgram; to_predict_program_ = program_.originProgram;
} }
const std::vector<std::shared_ptr<BlockDesc>> blocks = const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks(); to_predict_program_->Blocks();
for (int i = 0; i < blocks.size(); ++i) { for (int i = 0; i < blocks.size(); ++i) {
std::shared_ptr<BlockDesc> block_desc = blocks[i]; std::shared_ptr<BlockDesc> block_desc = blocks[i];
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops(); std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
for (int j = 0; j < ops.size(); ++j) { for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<OpDesc> op = ops[j]; std::shared_ptr<OpDesc> op = ops[j];
if (op->Type() == "conv2d" && if (op->Type() == "conv2d" && op->Input("Input")[0] == "pixel") {
op->Input("Input")[0] == "pixel") { Attribute strides_attr = op->GetAttrMap().at("strides");
Attribute strides_attr = op->GetAttrMap().at("strides"); std::vector<int> stride = strides_attr.Get<std::vector<int>>();
std::vector<int> stride = for (int k = 0; k < stride.size(); ++k) {
strides_attr.Get<std::vector<int>>();
for (int k = 0; k < stride.size(); ++k) {
}
std::shared_ptr<operators::ConvOp<Dtype, float>> conv =
std::make_shared<operators::ConvOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(conv);
}
} }
std::shared_ptr<operators::ConvOp<Dtype, float>> conv =
std::make_shared<operators::ConvOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(conv);
} }
} }
}
}
template <typename Dtype> template <typename Dtype>
std::shared_ptr<Tensor> Executor<Dtype>::predict(Tensor &t) { std::shared_ptr<Tensor> Executor<Dtype>::predict(Tensor &t) {
// feed // feed
auto scope = program_.scope; auto scope = program_.scope;
Variable *g_feed_value = scope->Var("pixel"); Variable *g_feed_value = scope->Var("pixel");
auto tensor = g_feed_value->GetMutable<Tensor>(); auto tensor = g_feed_value->GetMutable<Tensor>();
tensor->ShareDataWith(t); tensor->ShareDataWith(t);
Variable *con_output = scope->Var("conv2d_0.tmp_0"); Variable *con_output = scope->Var("conv2d_0.tmp_0");
Tensor *output_tensor = con_output->GetMutable<Tensor>(); Tensor *output_tensor = con_output->GetMutable<Tensor>();
output_tensor->mutable_data<float>({1, 16, 32, 32}); output_tensor->mutable_data<float>({1, 16, 32, 32});
// std::cout << typeid(output_tensor).name() << std::endl; // std::cout << typeid(output_tensor).name() << std::endl;
// std::cout << "output_tensor dims: " << output_tensor->dims() << // std::cout << "output_tensor dims: " << output_tensor->dims() <<
// std::endl; // std::endl;
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>(); std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor); out_tensor.reset(output_tensor);
predict(t, 0); predict(t, 0);
return out_tensor; return out_tensor;
} }
template <typename Dtype> template <typename Dtype>
void Executor<Dtype>::predict(const Tensor &t, int block_id) { void Executor<Dtype>::predict(const Tensor &t, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block = std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id); to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
++j) { auto op = ops_of_block_[*to_predict_block.get()][j];
auto op = ops_of_block_[*to_predict_block.get()][j]; // std::cout << "开始run" << std::endl;
// std::cout << "开始run" << std::endl; op->Run();
op->Run(); }
} }
}
template class Executor<CPU>; template class Executor<CPU>;
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -32,22 +32,22 @@ SOFTWARE. ...@@ -32,22 +32,22 @@ SOFTWARE.
#include "variable.h" #include "variable.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
template <typename Dtype> class Executor { template <typename Dtype> class Executor {
public: public:
Executor(const Program<Dtype> p); Executor(const Program<Dtype> p);
std::shared_ptr<Tensor> predict(Tensor &t); std::shared_ptr<Tensor> predict(Tensor &t);
private: private:
const framework::Program<Dtype> program_; const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_; std::shared_ptr<ProgramDesc> to_predict_program_;
void predict(const Tensor &t, int block_id); void predict(const Tensor &t, int block_id);
std::map<framework::BlockDesc, std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>> std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_; ops_of_block_;
bool use_optimize_ = false; bool use_optimize_ = false;
}; };
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
此差异已折叠。
此差异已折叠。
此差异已折叠。
...@@ -23,190 +23,186 @@ limitations under the License. */ ...@@ -23,190 +23,186 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
/* /*
* LoD is short for Level of Details. * LoD is short for Level of Details.
* *
* - in a level, each element indicates relative offset of the lower * - in a level, each element indicates relative offset of the lower
* level * level
* - the first element should be 0 and that indicates that this sequence * - the first element should be 0 and that indicates that this sequence
* start * start
* from 0 * from 0
* - each sequence's begin and end(no-inclusive) is level[id, id+1] * - each sequence's begin and end(no-inclusive) is level[id, id+1]
* *
* For example: * For example:
* 3-level LoD stores * 3-level LoD stores
* *
* 0 2 3 * 0 2 3
* 0 2 4 7 * 0 2 4 7
* 0 2 5 7 10 12 15 20 * 0 2 5 7 10 12 15 20
*/ */
using LoD = std::vector<std::vector<size_t>>; using LoD = std::vector<std::vector<size_t>>;
std::ostream &operator<<(std::ostream &os, const LoD &lod); std::ostream &operator<<(std::ostream &os, const LoD &lod);
std::ostream &operator<<(std::ostream &os, const LoDTensor &t); std::ostream &operator<<(std::ostream &os, const LoDTensor &t);
std::string LoDToString(const LoD &lod); std::string LoDToString(const LoD &lod);
LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin, LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin,
size_t elem_end); size_t elem_end);
/* /*
* Transform an LoD from relative offsets to absolute offsets. * Transform an LoD from relative offsets to absolute offsets.
*/ */
LoD ToAbsOffset(const LoD &in); LoD ToAbsOffset(const LoD &in);
bool operator==(const LoD &a, const LoD &b); bool operator==(const LoD &a, const LoD &b);
/* /*
* Check whether this lod's format is valid. * Check whether this lod's format is valid.
* *
* ATTENTION: * ATTENTION:
* - Empty lod is treated as valid. * - Empty lod is treated as valid.
* *
* It will check two things: * It will check two things:
* *
* 1. all the offsets in a level should be ascending(no same items * 1. all the offsets in a level should be ascending(no same items
* allows). * allows).
* 2. there should be more than 2 offsets existing in each level. * 2. there should be more than 2 offsets existing in each level.
* 3. the higher level's last offset should equals the lower level's * 3. the higher level's last offset should equals the lower level's
* size-1. * size-1.
* 4. the first offset(the begin offset) of each level should be 0. * 4. the first offset(the begin offset) of each level should be 0.
* 5. the lowest level's last offset should equals `tensor_height` if * 5. the lowest level's last offset should equals `tensor_height` if
* tensor_height>0. * tensor_height>0.
*/ */
bool CheckLoD(const LoD &in, int tensor_height = -1); bool CheckLoD(const LoD &in, int tensor_height = -1);
/* /*
* Check whether this absolute lod's format is valid. * Check whether this absolute lod's format is valid.
* *
* ATTENTION: * ATTENTION:
* - Empty lod is treated as valid. * - Empty lod is treated as valid.
* *
* It will check two things: * It will check two things:
* 1. all the offsets in a level should be ascending(no same items * 1. all the offsets in a level should be ascending(no same items
* allows) * allows)
* 2. there should be more than 2 offsets existing in each level. * 2. there should be more than 2 offsets existing in each level.
* 3. the first offset of each level should be 0, and the last should * 3. the first offset of each level should be 0, and the last should
* be the * be the
* same(the height of underlying tensor) or `tensor_height` if * same(the height of underlying tensor) or `tensor_height` if
* tensor_height>0. * tensor_height>0.
*/ */
bool CheckAbsLoD(const LoD &in, int tensor_height = -1); bool CheckAbsLoD(const LoD &in, int tensor_height = -1);
/* /*
* LoDTensor (Level of details Tensor) * LoDTensor (Level of details Tensor)
* see https://en.wikipedia.org/wiki/Level_of_details for reference. * see https://en.wikipedia.org/wiki/Level_of_details for reference.
*/ */
class LoDTensor : public Tensor { class LoDTensor : public Tensor {
public: public:
LoDTensor() : Tensor() {} LoDTensor() : Tensor() {}
explicit LoDTensor(const LoD &lod) : lod_(lod) {} explicit LoDTensor(const LoD &lod) : lod_(lod) {}
void set_lod(const LoD &lod) { lod_ = lod; } void set_lod(const LoD &lod) { lod_ = lod; }
const LoD &lod() const { return lod_; } const LoD &lod() const { return lod_; }
LoD *mutable_lod() { return &lod_; } LoD *mutable_lod() { return &lod_; }
/* /*
* Get the start offset and end offset of an element from LoD. * Get the start offset and end offset of an element from LoD.
*/ */
std::pair<size_t, size_t> lod_element(size_t level, std::pair<size_t, size_t> lod_element(size_t level, size_t elem) const {
size_t elem) const { // PADDLE_ENFORCE_LT(level, NumLevels());
// PADDLE_ENFORCE_LT(level, NumLevels()); // PADDLE_ENFORCE_LT(elem, NumElements(level));
// PADDLE_ENFORCE_LT(elem, NumElements(level)); return std::make_pair((lod_)[level][elem], (lod_)[level][elem + 1]);
return std::make_pair((lod_)[level][elem], }
(lod_)[level][elem + 1]);
} /*
* Number of LoDTensor's levels, each level has units of data, for
/* * example,
* Number of LoDTensor's levels, each level has units of data, for * in the sentence's view, article, paragraph, sentence are 3
* example, * levels.
* in the sentence's view, article, paragraph, sentence are 3 */
* levels. size_t NumLevels() const { return lod_.size(); }
*/
size_t NumLevels() const { return lod_.size(); } /*
* Number of elements in a level.
/* */
* Number of elements in a level. size_t NumElements(size_t level = 0) const {
*/ // PADDLE_ENFORCE_LT(level, NumLevels());
size_t NumElements(size_t level = 0) const { // the last offset is the end of last element
// PADDLE_ENFORCE_LT(level, NumLevels()); return (lod_)[level].size() - 1;
// the last offset is the end of last element }
return (lod_)[level].size() - 1;
} private:
LoD lod_;
private: };
LoD lod_;
}; /*
* Expand the `source` to fit the LoD of `lod`. For example, a `source`
/* * LoDTensor is
* Expand the `source` to fit the LoD of `lod`. For example, a `source` * - LoD: [0, 2]
* LoDTensor is * - tensor: [a0, a1]
* - LoD: [0, 2] * a `lod` is
* - tensor: [a0, a1] * - LoD: [0 3 5]
* a `lod` is * returns a new LoDTensor
* - LoD: [0 3 5] * - [a0 a0 a0 a1 a1]
* returns a new LoDTensor */
* - [a0 a0 a0 a1 a1] template <typename T>
*/ LoDTensor LodExpand(const LoDTensor &source, const LoD &lod, size_t level) {
template <typename T> LoD abs_lod = ToAbsOffset(lod);
LoDTensor LodExpand(const LoDTensor &source, const LoD &lod, const auto &lod_level = lod[level];
size_t level) { size_t num_instances = source.dims()[0];
LoD abs_lod = ToAbsOffset(lod);
const auto &lod_level = lod[level]; // new tensor
size_t num_instances = source.dims()[0]; LoDTensor tensor;
tensor.set_lod(lod);
// new tensor auto dims = source.dims();
LoDTensor tensor; dims[0] = lod_level.back();
tensor.set_lod(lod); tensor.Resize(dims);
auto dims = source.dims(); tensor.mutable_data<T>();
dims[0] = lod_level.back();
tensor.Resize(dims); // PADDLE_ENFORCE_EQ(num_instances, lod_level.size() - 1);
tensor.mutable_data<T>(); for (size_t ins = 0; ins < num_instances; ins++) {
for (size_t elem = lod_level[ins]; elem < lod_level[ins + 1]; elem++) {
// PADDLE_ENFORCE_EQ(num_instances, lod_level.size() - 1); auto slice = tensor.Slice(elem, elem + 1);
for (size_t ins = 0; ins < num_instances; ins++) { TensorCopy(source.Slice(ins, ins + 1), &slice);
for (size_t elem = lod_level[ins]; elem < lod_level[ins + 1];
elem++) {
auto slice = tensor.Slice(elem, elem + 1);
TensorCopy(source.Slice(ins, ins + 1), &slice);
}
}
return tensor;
} }
}
// Get the absolute offset of a lod[start_level][start_idx:end_idx] and return tensor;
// relative length of details for every levels(i.e., [start_level: ]). }
//
// For example, // Get the absolute offset of a lod[start_level][start_idx:end_idx] and
// lod = [[0, 3, 4, 8], [0, 9, 10, 11, 13, 17, 19, 22, 24]] // relative length of details for every levels(i.e., [start_level: ]).
// start_level = 0 //
// start_idx = 1 // For example,
// end_idx = 3 // lod = [[0, 3, 4, 8], [0, 9, 10, 11, 13, 17, 19, 22, 24]]
// // start_level = 0
// Returns: // start_idx = 1
// LoD = [[1, 4], [2, 4, 2, 3, 2]] // end_idx = 3
// pair<size_t, size_t> = {11, 24} //
std::pair<LoD, std::pair<size_t, size_t>> // Returns:
GetSubLoDAndAbsoluteOffset(const LoD &lod, size_t start_idx, // LoD = [[1, 4], [2, 4, 2, 3, 2]]
size_t end_idx, size_t start_level); // pair<size_t, size_t> = {11, 24}
std::pair<LoD, std::pair<size_t, size_t>>
void AppendLoD(LoD *lod, const LoD &lod_length); GetSubLoDAndAbsoluteOffset(const LoD &lod, size_t start_idx, size_t end_idx,
size_t start_level);
/*
* Serialize/Desiralize LoDTensor to std::ostream void AppendLoD(LoD *lod, const LoD &lod_length);
* You can pass ofstream or ostringstream to serilize to file
* or to a in memory string. GPU tensor will be copied to CPU. /*
*/ * Serialize/Desiralize LoDTensor to std::ostream
void SerializeToStream(std::ostream &os, const LoDTensor &tensor); * You can pass ofstream or ostringstream to serilize to file
* or to a in memory string. GPU tensor will be copied to CPU.
void DeserializeFromStream(std::istream &is, LoDTensor *tensor); */
void SerializeToStream(std::ostream &os, const LoDTensor &tensor);
} // namespace framework
void DeserializeFromStream(std::istream &is, LoDTensor *tensor);
} // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -5,58 +5,55 @@ ...@@ -5,58 +5,55 @@
#include "op_desc.h" #include "op_desc.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
OpDesc::OpDesc(const proto::OpDesc &desc) : desc_(desc) { OpDesc::OpDesc(const proto::OpDesc &desc) : desc_(desc) {
for (int i = 0; i < desc_.inputs_size(); ++i) { for (int i = 0; i < desc_.inputs_size(); ++i) {
const proto::OpDesc::Var &var = desc_.inputs(i); const proto::OpDesc::Var &var = desc_.inputs(i);
std::vector<std::string> &args = inputs_[var.parameter()]; std::vector<std::string> &args = inputs_[var.parameter()];
int arg_size = var.arguments_size(); int arg_size = var.arguments_size();
for (int j = 0; j < arg_size; ++j) { for (int j = 0; j < arg_size; ++j) {
args.push_back(var.arguments(j)); args.push_back(var.arguments(j));
}
}
for (int i = 0; i < desc_.outputs_size(); ++i) {
const proto::OpDesc::Var &var = desc_.outputs(i);
std::vector<std::string> &args = outputs_[var.parameter()];
int arg_size = var.arguments_size();
for (int j = 0; j < arg_size; ++j) {
args.push_back(var.arguments(j));
}
}
for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
std::string attr_name = attr.name();
if (attr.type() != proto::AttrType::BLOCK) {
attrs_[attr_name] = Attribute::GetAttrValue(attr);
// if (attr.type() == proto::AttrType::INT){
// std::cout << " attrName " << attr_name << " " <<
// attrs_[attr_name].Get<int>() << std::endl;
// }
}
}
} }
}
const std::vector<std::string> &
OpDesc::Input(const std::string &name) const { for (int i = 0; i < desc_.outputs_size(); ++i) {
return inputs_.find(name)->second; const proto::OpDesc::Var &var = desc_.outputs(i);
std::vector<std::string> &args = outputs_[var.parameter()];
int arg_size = var.arguments_size();
for (int j = 0; j < arg_size; ++j) {
args.push_back(var.arguments(j));
} }
}
const std::vector<std::string> &
OpDesc::Output(const std::string &name) const { for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
return outputs_.find(name)->second; std::string attr_name = attr.name();
if (attr.type() != proto::AttrType::BLOCK) {
attrs_[attr_name] = Attribute::GetAttrValue(attr);
// if (attr.type() == proto::AttrType::INT){
// std::cout << " attrName " << attr_name << " " <<
// attrs_[attr_name].Get<int>() << std::endl;
// }
} }
}
}
Attribute OpDesc::GetAttr(const std::string &name) const { const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
auto it = attrs_.find(name); return inputs_.find(name)->second;
return it->second; }
}
const std::unordered_map<std::string, Attribute> & const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
OpDesc::GetAttrMap() const { return outputs_.find(name)->second;
return attrs_; }
}
Attribute OpDesc::GetAttr(const std::string &name) const {
auto it = attrs_.find(name);
return it->second;
}
const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
return attrs_;
}
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -23,31 +23,29 @@ SOFTWARE. ...@@ -23,31 +23,29 @@ SOFTWARE.
#include "paddle_mobile_object.h" #include "paddle_mobile_object.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
class OpDesc : PaddleMobileObject { class OpDesc : PaddleMobileObject {
public: public:
OpDesc(const proto::OpDesc &desc); OpDesc(const proto::OpDesc &desc);
const std::vector<std::string> & const std::vector<std::string> &Input(const std::string &name) const;
Input(const std::string &name) const; const std::vector<std::string> &Output(const std::string &name) const;
const std::vector<std::string> & Attribute GetAttr(const std::string &name) const;
Output(const std::string &name) const;
Attribute GetAttr(const std::string &name) const;
const VariableNameMap &GetInputs() { return inputs_; } const VariableNameMap &GetInputs() { return inputs_; }
const VariableNameMap &GetOutputs() { return outputs_; } const VariableNameMap &GetOutputs() { return outputs_; }
const AttributeMap &GetAttrMap() const; const AttributeMap &GetAttrMap() const;
const std::string &Type() { return desc_.type(); }; const std::string &Type() { return desc_.type(); };
private: private:
proto::OpDesc desc_; proto::OpDesc desc_;
VariableNameMap inputs_; VariableNameMap inputs_;
VariableNameMap outputs_; VariableNameMap outputs_;
AttributeMap attrs_; AttributeMap attrs_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -22,74 +22,73 @@ SOFTWARE. ...@@ -22,74 +22,73 @@ SOFTWARE.
#include "framework.pb.h" #include "framework.pb.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
template <typename Dtype> struct OpInfo { template <typename Dtype> struct OpInfo {
OpCreator<Dtype> creator_; OpCreator<Dtype> creator_;
const OpCreator<Dtype> &Creator() const { const OpCreator<Dtype> &Creator() const {
// PADDLE_ENFORCE_NOT_NULL(creator_, // PADDLE_ENFORCE_NOT_NULL(creator_,
// "Operator Creator has not been // "Operator Creator has not been
// registered"); // registered");
return creator_; return creator_;
} }
}; };
template <typename Dtype> class OpInfoMap; template <typename Dtype> class OpInfoMap;
template <typename Dtype> template <typename Dtype> static OpInfoMap<Dtype> *g_op_info_map = nullptr;
static OpInfoMap<Dtype> *g_op_info_map = nullptr;
template <typename Dtype> class OpInfoMap {
template <typename Dtype> class OpInfoMap { public:
public: static OpInfoMap &Instance() {
static OpInfoMap &Instance() { if (g_op_info_map<Dtype> == nullptr) {
if (g_op_info_map<Dtype> == nullptr) { g_op_info_map<Dtype> = new OpInfoMap();
g_op_info_map<Dtype> = new OpInfoMap(); }
} return *g_op_info_map<Dtype>;
return *g_op_info_map<Dtype>; };
};
bool Has(const std::string &op_type) const {
bool Has(const std::string &op_type) const { return map_.find(op_type) != map_.end();
return map_.find(op_type) != map_.end(); }
}
void Insert(const std::string &type, const OpInfo<Dtype> &info) {
void Insert(const std::string &type, const OpInfo<Dtype> &info) { // PADDLE_ENFORCE(!Has(type), "Operator %s has been
// PADDLE_ENFORCE(!Has(type), "Operator %s has been // registered", type);
// registered", type); map_.insert({type, info});
map_.insert({type, info}); }
}
const OpInfo<Dtype> &Get(const std::string &type) const {
const OpInfo<Dtype> &Get(const std::string &type) const { auto op_info_ptr = GetNullable(type);
auto op_info_ptr = GetNullable(type); // PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not
// PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not // been
// been // registered",
// registered", // type);
// type); return *op_info_ptr;
return *op_info_ptr; }
}
const OpInfo<Dtype> *GetNullable(const std::string &type) const {
const OpInfo<Dtype> *GetNullable(const std::string &type) const { auto it = map_.find(type);
auto it = map_.find(type); if (it == map_.end()) {
if (it == map_.end()) { return nullptr;
return nullptr; } else {
} else { return &it->second;
return &it->second; }
} }
}
const std::unordered_map<std::string, OpInfo<Dtype>> &map() const {
const std::unordered_map<std::string, OpInfo<Dtype>> &map() const { return map_;
return map_; }
}
std::unordered_map<std::string, OpInfo<Dtype>> *mutable_map() {
std::unordered_map<std::string, OpInfo<Dtype>> *mutable_map() { return &map_;
return &map_; }
}
private:
private: OpInfoMap() = default;
OpInfoMap() = default; std::unordered_map<std::string, OpInfo<Dtype>> map_;
std::unordered_map<std::string, OpInfo<Dtype>> map_;
// DISABLE_COPY_AND_ASSIGN(OpInfoMap);
// DISABLE_COPY_AND_ASSIGN(OpInfoMap); };
};
} // namespace framework
} // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -22,51 +22,44 @@ SOFTWARE. ...@@ -22,51 +22,44 @@ SOFTWARE.
#include "framework.pb.h" #include "framework.pb.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
struct OpKernelType { struct OpKernelType {
struct Hash { struct Hash {
size_t operator()(const OpKernelType &key) const { size_t operator()(const OpKernelType &key) const {
int data_type = static_cast<int>(key.data_type_) int data_type = static_cast<int>(key.data_type_) << LEFT_SHIFT;
<< LEFT_SHIFT; int data_layout = static_cast<int>(key.data_layout_)
int data_layout = static_cast<int>(key.data_layout_) << (LEFT_SHIFT * 2);
<< (LEFT_SHIFT * 2);
std::hash<int> hasher; std::hash<int> hasher;
return hasher(data_type + data_layout); return hasher(data_type + data_layout);
} }
}; };
// place, data_type, library_type kinds less than 2^8 // place, data_type, library_type kinds less than 2^8
constexpr static int LEFT_SHIFT = 8; constexpr static int LEFT_SHIFT = 8;
proto::VarType::Type data_type_; proto::VarType::Type data_type_;
DataLayout data_layout_; DataLayout data_layout_;
OpKernelType(proto::VarType::Type data_type, OpKernelType(proto::VarType::Type data_type,
DataLayout data_layout = DataLayout::kAnyLayout) DataLayout data_layout = DataLayout::kAnyLayout)
: data_type_(data_type), data_layout_(data_layout) {} : data_type_(data_type), data_layout_(data_layout) {}
bool operator==(const OpKernelType &o) const { bool operator==(const OpKernelType &o) const {
return data_type_ == o.data_type_ && return data_type_ == o.data_type_ && data_layout_ == o.data_layout_;
data_layout_ == o.data_layout_; }
}
bool operator!=(const OpKernelType &o) const { bool operator!=(const OpKernelType &o) const { return !(*this == o); }
return !(*this == o); };
}
};
inline bool NeedTransformLayout(const DataLayout &l, inline bool NeedTransformLayout(const DataLayout &l, const DataLayout &r) {
const DataLayout &r) { return l != DataLayout::kAnyLayout && r != DataLayout::kAnyLayout && l != r;
return l != DataLayout::kAnyLayout && r != DataLayout::kAnyLayout && }
l != r;
}
inline bool TransFromNeeded(const OpKernelType &l, inline bool TransFromNeeded(const OpKernelType &l, const OpKernelType &r) {
const OpKernelType &r) { return (l.data_type_ != r.data_type_) ||
return (l.data_type_ != r.data_type_) || NeedTransformLayout(l.data_layout_, r.data_layout_);
NeedTransformLayout(l.data_layout_, r.data_layout_); }
}
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -19,8 +19,8 @@ SOFTWARE. ...@@ -19,8 +19,8 @@ SOFTWARE.
#pragma once #pragma once
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
// this class not only make proto but also init attribute checkers. // this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {}; class OpProtoAndCheckerMaker {};
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -20,23 +20,23 @@ SOFTWARE. ...@@ -20,23 +20,23 @@ SOFTWARE.
#include "op_info.h" #include "op_info.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
template <typename Dtype> template <typename Dtype>
OperatorBase<Dtype>::OperatorBase(const std::string &type, OperatorBase<Dtype>::OperatorBase(const std::string &type,
const VariableNameMap &inputs, const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const AttributeMap &attrs,
std::shared_ptr<Scope> scope) std::shared_ptr<Scope> scope)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs), : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs),
scope_(scope) { scope_(scope) {
CheckAllInputOutputSet(); CheckAllInputOutputSet();
} }
template <typename Dtype> template <typename Dtype>
void OperatorBase<Dtype>::CheckAllInputOutputSet() const {} void OperatorBase<Dtype>::CheckAllInputOutputSet() const {}
template class OperatorBase<CPU>; template class OperatorBase<CPU>;
template class OperatorWithKernel<CPU>; template class OperatorWithKernel<CPU>;
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
此差异已折叠。
...@@ -23,14 +23,14 @@ SOFTWARE. ...@@ -23,14 +23,14 @@ SOFTWARE.
namespace paddle_mobile { namespace paddle_mobile {
class PaddleMobileObject { class PaddleMobileObject {
public: public:
virtual std::string ToString() { virtual std::string ToString() {
char address[128] = {0}; char address[128] = {0};
sprintf(address, "%p", this); sprintf(address, "%p", this);
return std::string(address); return std::string(address);
} }
private: private:
}; };
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -25,22 +25,22 @@ SOFTWARE. ...@@ -25,22 +25,22 @@ SOFTWARE.
#include "framework/paddle_mobile_object.h" #include "framework/paddle_mobile_object.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
class Node : PaddleMobileObject { class Node : PaddleMobileObject {
public: public:
Node(const std::string &type) : type_(type) {} Node(const std::string &type) : type_(type) {}
Node &operator>(const Node &out); Node &operator>(const Node &out);
bool operator==(const Node &in); bool operator==(const Node &in);
std::string ToString() const; std::string ToString() const;
private: private:
std::string ToString(std::string blank) const; std::string ToString(std::string blank) const;
std::vector<std::shared_ptr<Node>> outputs_; std::vector<std::shared_ptr<Node>> outputs_;
std::string type_; std::string type_;
}; };
Print &operator<<(Print &printer, const Node &node); Print &operator<<(Print &printer, const Node &node);
} } // namespace framework
} } // namespace paddle_mobile
...@@ -19,7 +19,7 @@ SOFTWARE. ...@@ -19,7 +19,7 @@ SOFTWARE.
#include "program_optimize.h" #include "program_optimize.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
std::shared_ptr<ProgramDesc> ProgramOptimize::Optimize() {} std::shared_ptr<ProgramDesc> ProgramOptimize::Optimize() {}
} } // namespace framework
} } // namespace paddle_mobile
...@@ -17,5 +17,5 @@ SOFTWARE. ...@@ -17,5 +17,5 @@ SOFTWARE.
==============================================================================*/ ==============================================================================*/
namespace paddle_mobile { namespace paddle_mobile {
namespace framework {} namespace framework {}
} // namespace paddle_mobile } // namespace paddle_mobile
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
#include "framework/scope.h" #include "framework/scope.h"
#include "framework/tensor.h" #include "framework/tensor.h"
#include "framework/variable.h" #include "framework/variable.h"
#include "framework/variable.h"
#include "io.h" #include "io.h"
#include "test_helper.h" #include "test_helper.h"
#include <map> #include <map>
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册