diff --git a/CMakeLists.txt b/CMakeLists.txt index 6bc6a8077c1e9b242af50f1b912e91bbc6c3a9fb..fb1c85bf742c80308edb009c080cb0da6d409ee0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,7 +28,9 @@ if(NOT CMAKE_CROSSCOMPILING) endif(NOT CMAKE_CROSSCOMPILING) find_package(Git REQUIRED) find_package(Threads REQUIRED) -find_package(Boost QUIET) +if(NOT ANDROID) + find_package(Boost QUIET) +endif() include(simd) @@ -151,7 +153,9 @@ if(WITH_GOLANG) endif(WITH_GOLANG) add_subdirectory(paddle) -add_subdirectory(python) +if(WITH_PYTHON) + add_subdirectory(python) +endif() if(WITH_DOC) add_subdirectory(doc) endif() diff --git a/cmake/cross_compiling/android.cmake b/cmake/cross_compiling/android.cmake index 9724c16122ab2e6be55864c8716698c9b9d7c3f0..dcfbc5d0129d7763daaf17c33d2b7791e87d3018 100644 --- a/cmake/cross_compiling/android.cmake +++ b/cmake/cross_compiling/android.cmake @@ -106,6 +106,9 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0") SET(CMAKE_SYSTEM_PROCESSOR armv7-a) ENDIF() ENDIF() + IF(ANDROID_ABI STREQUAL "arm64-v8a") + SET(ANDROID_TOOLCHAIN_NAME aarch64-linux-android) + ENDIF() SET(ANDROID_TOOLCHAIN_PREFIX "${ANDROID_TOOLCHAIN_ROOT}/bin/${ANDROID_TOOLCHAIN_NAME}-") ENDIF() @@ -162,6 +165,10 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0") ENDIF() ENDIF() + IF(ANDROID_ABI STREQUAL "arm64-v8a") + LIST(APPEND ANDROID_COMPILER_FLAGS -march=armv8-a) + ENDIF() + STRING(REPLACE ";" " " ANDROID_COMPILER_FLAGS "${ANDROID_COMPILER_FLAGS}") STRING(REPLACE ";" " " ANDROID_LINKER_FLAGS "${ANDROID_LINKER_FLAGS}") diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index 5b9d9844ed21ceb507a8e01676c3533f4e3dd8fb..60a1041936437775e0994157b8ffcb7c52b7ab87 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -32,7 +32,12 @@ IF(NOT ${CBLAS_FOUND}) # arm_soft_fp_abi branch of OpenBLAS to support softfp # https://github.com/xianyi/OpenBLAS/tree/arm_soft_fp_abi SET(OPENBLAS_COMMIT "b5c96fcfcdc82945502a2303116a64d89985daf5") - SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 ARM_SOFTFP_ABI=1 USE_THREAD=0) + IF(ANDROID_ABI MATCHES "^armeabi(-v7a)?$") + SET(TARGET "ARMV7") + ELSEIF(ANDROID_ABI STREQUAL "arm64-v8a") + SET(TARGET "ARMV8") + ENDIF() + SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=${TARGET} ARM_SOFTFP_ABI=1 USE_THREAD=0) ELSEIF(RPI) # use hardfp SET(OPENBLAS_COMMIT "v0.2.19") diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 71ee266611df27cd72c586c49ce2765e6ea0471b..716955c7b42f3d05b3ec8387cf81dd9cb1c409bf 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -90,11 +90,11 @@ # including binary directory for generated headers. include_directories(${CMAKE_CURRENT_BINARY_DIR}) -if(NOT APPLE) +if(NOT APPLE AND NOT ANDROID) find_package(Threads REQUIRED) link_libraries(${CMAKE_THREAD_LIBS_INIT}) set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl -lrt") -endif(NOT APPLE) +endif(NOT APPLE AND NOT ANDROID) function(merge_static_libs TARGET_NAME) set(libs ${ARGN}) diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index 48351ab6d016b3594a4d983ba0980b2b0eb11c89..b331b8126cadc2c5df516fb241913415b2e3e73d 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -34,7 +34,7 @@ func main() { var idx int - var cp *pserver.Checkpoint + var cp pserver.Checkpoint var e *pserver.EtcdClient if *index >= 0 { idx = *index diff --git a/go/pserver/service.go b/go/pserver/service.go index 65db6970a703606cb74dc9f7677757f16f1e7181..fec2ec61dc67756439d9fa478788d1f155876538 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -83,7 +83,7 @@ type parameterCheckpoint struct { } // NewCheckpointFromFile loads parameters and state from checkpoint file -func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint, error) { +func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, error) { v, err := e.GetKey(PsPath+string(idx), 3*time.Second) if err != nil { return nil, err @@ -110,7 +110,7 @@ func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint, } dec := gob.NewDecoder(bytes.NewReader(content)) - cp := &Checkpoint{} + cp := Checkpoint{} if err = dec.Decode(cp); err != nil { return nil, err } @@ -119,7 +119,7 @@ func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint, // NewService creates a new service, will bypass etcd registration if no // endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint. -func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp *Checkpoint) (*Service, error) { +func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { s := &Service{ idx: idx, checkpointInterval: interval, @@ -130,7 +130,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient s.initialized = make(chan struct{}) if cp != nil { - for _, item := range *cp { + for _, item := range cp { p := ParameterWithConfig{ Param: item.Param, Config: item.Config, diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 248c7a1a3b866ae3bf2af33d0ff67b92d0f9c456..e46da822c6c2e99b682212dc724126ac47150a24 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -119,6 +119,7 @@ class OpRegistry { op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr); } op_checkers().at(op_type).Check(op->attrs_); + op->Init(); return op; } diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 0ce422e007c39ce0c3f5f7a89650cc211919ea8f..4336115670031fed3672e799340241c9d6f45d5f 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -49,6 +49,10 @@ class OperatorBase { std::string DebugString() const; + /// Init will be called after CreateOperator, you can put some initialization + /// logic here. + virtual void Init() {} + /// InferShape infer the size of Variables used by this Operator with /// information inside scope virtual void InferShape(const std::shared_ptr& scope) const = 0; diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index be8c4be2d429648b3c8a708c7f8bdcae3ff2d283..01b87bb50e7a10134fc43fa6aa2716e13c4dbe33 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -21,14 +21,19 @@ namespace framework { class OperatorTest : public OperatorBase { public: + void Init() override { x = 1; } void InferShape(const std::shared_ptr& scope) const override {} void Run(const std::shared_ptr& scope, const platform::DeviceContext& dev_ctx) const override { float scale = GetAttr("scale"); ASSERT_NEAR(scale, 3.14, 1e-5); ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); + ASSERT_EQ(x, 1); ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); } + + public: + float x = 0; }; class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {