提交 f6fe3715 编写于 作者: Q qijun

Merge remote-tracking branch 'baidu/develop' into fix_bug_dynload

...@@ -28,7 +28,9 @@ if(NOT CMAKE_CROSSCOMPILING) ...@@ -28,7 +28,9 @@ if(NOT CMAKE_CROSSCOMPILING)
endif(NOT CMAKE_CROSSCOMPILING) endif(NOT CMAKE_CROSSCOMPILING)
find_package(Git REQUIRED) find_package(Git REQUIRED)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
find_package(Boost QUIET) if(NOT ANDROID)
find_package(Boost QUIET)
endif()
include(simd) include(simd)
...@@ -151,7 +153,9 @@ if(WITH_GOLANG) ...@@ -151,7 +153,9 @@ if(WITH_GOLANG)
endif(WITH_GOLANG) endif(WITH_GOLANG)
add_subdirectory(paddle) add_subdirectory(paddle)
add_subdirectory(python) if(WITH_PYTHON)
add_subdirectory(python)
endif()
if(WITH_DOC) if(WITH_DOC)
add_subdirectory(doc) add_subdirectory(doc)
endif() endif()
...@@ -106,6 +106,9 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0") ...@@ -106,6 +106,9 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
SET(CMAKE_SYSTEM_PROCESSOR armv7-a) SET(CMAKE_SYSTEM_PROCESSOR armv7-a)
ENDIF() ENDIF()
ENDIF() ENDIF()
IF(ANDROID_ABI STREQUAL "arm64-v8a")
SET(ANDROID_TOOLCHAIN_NAME aarch64-linux-android)
ENDIF()
SET(ANDROID_TOOLCHAIN_PREFIX "${ANDROID_TOOLCHAIN_ROOT}/bin/${ANDROID_TOOLCHAIN_NAME}-") SET(ANDROID_TOOLCHAIN_PREFIX "${ANDROID_TOOLCHAIN_ROOT}/bin/${ANDROID_TOOLCHAIN_NAME}-")
ENDIF() ENDIF()
...@@ -162,6 +165,10 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0") ...@@ -162,6 +165,10 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
ENDIF() ENDIF()
ENDIF() ENDIF()
IF(ANDROID_ABI STREQUAL "arm64-v8a")
LIST(APPEND ANDROID_COMPILER_FLAGS -march=armv8-a)
ENDIF()
STRING(REPLACE ";" " " ANDROID_COMPILER_FLAGS "${ANDROID_COMPILER_FLAGS}") STRING(REPLACE ";" " " ANDROID_COMPILER_FLAGS "${ANDROID_COMPILER_FLAGS}")
STRING(REPLACE ";" " " ANDROID_LINKER_FLAGS "${ANDROID_LINKER_FLAGS}") STRING(REPLACE ";" " " ANDROID_LINKER_FLAGS "${ANDROID_LINKER_FLAGS}")
......
...@@ -32,7 +32,12 @@ IF(NOT ${CBLAS_FOUND}) ...@@ -32,7 +32,12 @@ IF(NOT ${CBLAS_FOUND})
# arm_soft_fp_abi branch of OpenBLAS to support softfp # arm_soft_fp_abi branch of OpenBLAS to support softfp
# https://github.com/xianyi/OpenBLAS/tree/arm_soft_fp_abi # https://github.com/xianyi/OpenBLAS/tree/arm_soft_fp_abi
SET(OPENBLAS_COMMIT "b5c96fcfcdc82945502a2303116a64d89985daf5") SET(OPENBLAS_COMMIT "b5c96fcfcdc82945502a2303116a64d89985daf5")
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 ARM_SOFTFP_ABI=1 USE_THREAD=0) IF(ANDROID_ABI MATCHES "^armeabi(-v7a)?$")
SET(TARGET "ARMV7")
ELSEIF(ANDROID_ABI STREQUAL "arm64-v8a")
SET(TARGET "ARMV8")
ENDIF()
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=${TARGET} ARM_SOFTFP_ABI=1 USE_THREAD=0)
ELSEIF(RPI) ELSEIF(RPI)
# use hardfp # use hardfp
SET(OPENBLAS_COMMIT "v0.2.19") SET(OPENBLAS_COMMIT "v0.2.19")
......
...@@ -90,11 +90,11 @@ ...@@ -90,11 +90,11 @@
# including binary directory for generated headers. # including binary directory for generated headers.
include_directories(${CMAKE_CURRENT_BINARY_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR})
if(NOT APPLE) if(NOT APPLE AND NOT ANDROID)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
link_libraries(${CMAKE_THREAD_LIBS_INIT}) link_libraries(${CMAKE_THREAD_LIBS_INIT})
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl -lrt") set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl -lrt")
endif(NOT APPLE) endif(NOT APPLE AND NOT ANDROID)
function(merge_static_libs TARGET_NAME) function(merge_static_libs TARGET_NAME)
set(libs ${ARGN}) set(libs ${ARGN})
......
...@@ -34,7 +34,7 @@ func main() { ...@@ -34,7 +34,7 @@ func main() {
var idx int var idx int
var cp *pserver.Checkpoint var cp pserver.Checkpoint
var e *pserver.EtcdClient var e *pserver.EtcdClient
if *index >= 0 { if *index >= 0 {
idx = *index idx = *index
......
...@@ -83,7 +83,7 @@ type parameterCheckpoint struct { ...@@ -83,7 +83,7 @@ type parameterCheckpoint struct {
} }
// NewCheckpointFromFile loads parameters and state from checkpoint file // NewCheckpointFromFile loads parameters and state from checkpoint file
func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint, error) { func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, error) {
v, err := e.GetKey(PsPath+string(idx), 3*time.Second) v, err := e.GetKey(PsPath+string(idx), 3*time.Second)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -110,7 +110,7 @@ func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint, ...@@ -110,7 +110,7 @@ func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint,
} }
dec := gob.NewDecoder(bytes.NewReader(content)) dec := gob.NewDecoder(bytes.NewReader(content))
cp := &Checkpoint{} cp := Checkpoint{}
if err = dec.Decode(cp); err != nil { if err = dec.Decode(cp); err != nil {
return nil, err return nil, err
} }
...@@ -119,7 +119,7 @@ func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint, ...@@ -119,7 +119,7 @@ func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint,
// NewService creates a new service, will bypass etcd registration if no // NewService creates a new service, will bypass etcd registration if no
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint. // endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp *Checkpoint) (*Service, error) { func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
s := &Service{ s := &Service{
idx: idx, idx: idx,
checkpointInterval: interval, checkpointInterval: interval,
...@@ -130,7 +130,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient ...@@ -130,7 +130,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
s.initialized = make(chan struct{}) s.initialized = make(chan struct{})
if cp != nil { if cp != nil {
for _, item := range *cp { for _, item := range cp {
p := ParameterWithConfig{ p := ParameterWithConfig{
Param: item.Param, Param: item.Param,
Config: item.Config, Config: item.Config,
......
...@@ -119,6 +119,7 @@ class OpRegistry { ...@@ -119,6 +119,7 @@ class OpRegistry {
op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr); op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
} }
op_checkers().at(op_type).Check(op->attrs_); op_checkers().at(op_type).Check(op->attrs_);
op->Init();
return op; return op;
} }
......
...@@ -49,6 +49,10 @@ class OperatorBase { ...@@ -49,6 +49,10 @@ class OperatorBase {
std::string DebugString() const; std::string DebugString() const;
/// Init will be called after CreateOperator, you can put some initialization
/// logic here.
virtual void Init() {}
/// InferShape infer the size of Variables used by this Operator with /// InferShape infer the size of Variables used by this Operator with
/// information inside scope /// information inside scope
virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0; virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0;
......
...@@ -21,14 +21,19 @@ namespace framework { ...@@ -21,14 +21,19 @@ namespace framework {
class OperatorTest : public OperatorBase { class OperatorTest : public OperatorBase {
public: public:
void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {} void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope, void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
float scale = GetAttr<float>("scale"); float scale = GetAttr<float>("scale");
ASSERT_NEAR(scale, 3.14, 1e-5); ASSERT_NEAR(scale, 3.14, 1e-5);
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
ASSERT_EQ(x, 1);
ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
} }
public:
float x = 0;
}; };
class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册