Commit 0b0b3ba1 authored by qijun

Merge remote-tracking branch 'baidu/develop' into tensor_to_EigenTensor

@@ -28,7 +28,9 @@ if(NOT CMAKE_CROSSCOMPILING)
 endif(NOT CMAKE_CROSSCOMPILING)
 find_package(Git REQUIRED)
 find_package(Threads REQUIRED)
-find_package(Boost QUIET)
+if(NOT ANDROID)
+    find_package(Boost QUIET)
+endif()
 include(simd)
@@ -151,7 +153,9 @@ if(WITH_GOLANG)
 endif(WITH_GOLANG)
 add_subdirectory(paddle)
-add_subdirectory(python)
+if(WITH_PYTHON)
+    add_subdirectory(python)
+endif()
 if(WITH_DOC)
     add_subdirectory(doc)
 endif()
@@ -106,6 +106,9 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
         SET(CMAKE_SYSTEM_PROCESSOR armv7-a)
     ENDIF()
 ENDIF()
+IF(ANDROID_ABI STREQUAL "arm64-v8a")
+    SET(ANDROID_TOOLCHAIN_NAME aarch64-linux-android)
+ENDIF()
 SET(ANDROID_TOOLCHAIN_PREFIX "${ANDROID_TOOLCHAIN_ROOT}/bin/${ANDROID_TOOLCHAIN_NAME}-")
 ENDIF()
@@ -162,6 +165,10 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
     ENDIF()
 ENDIF()
+IF(ANDROID_ABI STREQUAL "arm64-v8a")
+    LIST(APPEND ANDROID_COMPILER_FLAGS -march=armv8-a)
+ENDIF()
 STRING(REPLACE ";" " " ANDROID_COMPILER_FLAGS "${ANDROID_COMPILER_FLAGS}")
 STRING(REPLACE ";" " " ANDROID_LINKER_FLAGS "${ANDROID_LINKER_FLAGS}")
......
@@ -52,6 +52,7 @@ ExternalProject_Add(
 ADD_LIBRARY(glog STATIC IMPORTED GLOBAL)
 SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES})
-ADD_DEPENDENCIES(glog extern_glog)
+ADD_DEPENDENCIES(glog extern_glog gflags)
+LINK_LIBRARIES(glog gflags)
 LIST(APPEND external_project_dependencies glog)
@@ -32,7 +32,12 @@ IF(NOT ${CBLAS_FOUND})
         # arm_soft_fp_abi branch of OpenBLAS to support softfp
         # https://github.com/xianyi/OpenBLAS/tree/arm_soft_fp_abi
         SET(OPENBLAS_COMMIT "b5c96fcfcdc82945502a2303116a64d89985daf5")
-        SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 ARM_SOFTFP_ABI=1 USE_THREAD=0)
+        IF(ANDROID_ABI MATCHES "^armeabi(-v7a)?$")
+            SET(TARGET "ARMV7")
+        ELSEIF(ANDROID_ABI STREQUAL "arm64-v8a")
+            SET(TARGET "ARMV8")
+        ENDIF()
+        SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=${TARGET} ARM_SOFTFP_ABI=1 USE_THREAD=0)
     ELSEIF(RPI)
         # use hardfp
         SET(OPENBLAS_COMMIT "v0.2.19")
......
@@ -90,11 +90,11 @@
 # including binary directory for generated headers.
 include_directories(${CMAKE_CURRENT_BINARY_DIR})

-if(NOT APPLE)
+if(NOT APPLE AND NOT ANDROID)
     find_package(Threads REQUIRED)
     link_libraries(${CMAKE_THREAD_LIBS_INIT})
     set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl -lrt")
-endif(NOT APPLE)
+endif(NOT APPLE AND NOT ANDROID)

 function(merge_static_libs TARGET_NAME)
     set(libs ${ARGN})
......
@@ -34,7 +34,7 @@ func main() {
     var idx int

-    var cp *pserver.Checkpoint
+    var cp pserver.Checkpoint
     var e *pserver.EtcdClient
     if *index >= 0 {
         idx = *index
......
@@ -83,7 +83,7 @@ type parameterCheckpoint struct {
 }

 // NewCheckpointFromFile loads parameters and state from a checkpoint file
-func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint, error) {
+func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, error) {
     v, err := e.GetKey(PsPath+string(idx), 3*time.Second)
     if err != nil {
         return nil, err
@@ -110,7 +110,7 @@ func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint,
     }

     dec := gob.NewDecoder(bytes.NewReader(content))
-    cp := &Checkpoint{}
+    cp := Checkpoint{}
     if err = dec.Decode(cp); err != nil {
         return nil, err
     }
@@ -119,7 +119,7 @@ func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint,
 // NewService creates a new service, and will bypass etcd registration if no
 // endpoints are specified. It will recover from the checkpoint file if a
 // specified checkpoint exists.
-func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp *Checkpoint) (*Service, error) {
+func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
     s := &Service{
         idx:                idx,
         checkpointInterval: interval,
@@ -130,7 +130,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
     s.initialized = make(chan struct{})
     if cp != nil {
-        for _, item := range *cp {
+        for _, item := range cp {
             p := ParameterWithConfig{
                 Param:  item.Param,
                 Config: item.Config,
......
@@ -14,6 +14,7 @@ if(Boost_FOUND)
     add_subdirectory(memory)
     add_subdirectory(platform)
     add_subdirectory(framework)
+    add_subdirectory(operators)
     add_subdirectory(pybind)
 endif()
......
@@ -11,8 +11,8 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
 cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
 proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
 cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
-cc_library(operator SRCS operator.cc DEPS op_desc protobuf)
-cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry place)
+cc_library(operator SRCS operator.cc DEPS op_desc device_context)
+cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
 cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
 cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
 py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
......
 #pragma once

 #include <algorithm>
+#include <type_traits>
 #include "paddle/framework/attr_checker.h"
 #include "paddle/framework/op_desc.pb.h"
 #include "paddle/framework/op_proto.pb.h"
@@ -81,8 +82,6 @@ class OpProtoAndCheckerMaker {
     return op_checker_->AddAttrChecker<T>(name);
   }

-  void AddType(const std::string& op_type) { proto_->set_type(op_type); }
-
   void AddComment(const std::string& comment) {
     *(proto_->mutable_comment()) = comment;
   }
@@ -101,8 +100,11 @@ class OpRegistry {
     OpProto& op_proto = protos()[op_type];
     OpAttrChecker& op_checker = op_checkers()[op_type];
     ProtoMakerType(&op_proto, &op_checker);
-    PADDLE_ENFORCE(op_proto.IsInitialized(),
-                   "Fail to initialize %s's OpProto !", op_type);
+    *op_proto.mutable_type() = op_type;
+    PADDLE_ENFORCE(
+        op_proto.IsInitialized(),
+        "Fail to initialize %s's OpProto, because %s is not initialized",
+        op_type, op_proto.InitializationErrorString());
   }

   static OperatorBase* CreateOp(const OpDesc& op_desc) {
@@ -119,6 +121,7 @@ class OpRegistry {
       op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
     }
     op_checkers().at(op_type).Check(op->attrs_);
+    op->Init();
     return op;
   }
@@ -142,18 +145,73 @@ class OpRegistry {
 template <typename OpType, typename ProtoMakerType>
 class OpRegisterHelper {
  public:
-  OpRegisterHelper(std::string op_type) {
+  OpRegisterHelper(const char* op_type) {
     OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
   }
 };

-#define REGISTER_OP(type, op_class, op_maker_class)                          \
-  class op_class##Register {                                                 \
-   private:                                                                  \
-    const static OpRegisterHelper<op_class, op_maker_class> reg;             \
-  };                                                                         \
-  const OpRegisterHelper<op_class, op_maker_class> op_class##Register::reg(  \
-      #type)
+#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg)                        \
+  struct __test_global_namespace_##uniq_name##__ {};                          \
+  static_assert(std::is_same<::__test_global_namespace_##uniq_name##__,       \
+                             __test_global_namespace_##uniq_name##__>::value, \
+                msg)
+
+#define REGISTER_OP(__op_type, __op_class, __op_maker_class)                 \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_op__##__op_type,                      \
+                                 "REGISTER_OP must be in global namespace"); \
+  static ::paddle::framework::OpRegisterHelper<__op_class, __op_maker_class> \
+      __op_register_##__op_type##__(#__op_type);                             \
+  int __op_register_##__op_type##_handle__() { return 0; }
+
+#define REGISTER_OP_KERNEL(type, GPU_OR_CPU, PlaceType, KernelType)       \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                         \
+      __reg_op_kernel_##type##_##GPU_OR_CPU##__,                          \
+      "REGISTER_OP_KERNEL must be in global namespace");                  \
+  struct __op_kernel_register__##type##__ {                               \
+    __op_kernel_register__##type##__() {                                  \
+      ::paddle::framework::OperatorWithKernel::OpKernelKey key;           \
+      key.place_ = PlaceType();                                           \
+      ::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
+          .reset(new KernelType());                                       \
+    }                                                                     \
+  };                                                                      \
+  static __op_kernel_register__##type##__ __reg_kernel_##type##__;        \
+  int __op_kernel_register_##type##_handle_##GPU_OR_CPU##__() { return 0; }
+
+#define REGISTER_OP_GPU_KERNEL(type, KernelType) \
+  REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, KernelType)
+
+#define REGISTER_OP_CPU_KERNEL(type, KernelType) \
+  REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, KernelType)
+
+#define USE_OP_WITHOUT_KERNEL(op_type)                      \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                           \
+      __use_op_without_kernel_##op_type,                    \
+      "USE_OP_WITHOUT_KERNEL must be in global namespace"); \
+  extern int __op_register_##op_type##_handle__();          \
+  static int __use_op_ptr_##op_type##_without_kernel__      \
+      __attribute__((unused)) = __op_register_##op_type##_handle__()
+
+#define USE_OP_KERNEL(op_type, DEVICE_TYPE)                               \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                         \
+      __use_op_kernel_##op_type##_##DEVICE_TYPE##__,                      \
+      "USE_OP_KERNEL must be in global namespace");                       \
+  extern int __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__(); \
+  static int __use_op_ptr_##op_type##_##DEVICE_TYPE##_kernel__            \
+      __attribute__((unused)) =                                           \
+          __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__()
+
+#ifdef PADDLE_ONLY_CPU
+#define USE_OP(op_type)           \
+  USE_OP_WITHOUT_KERNEL(op_type); \
+  USE_OP_KERNEL(op_type, CPU);
+#else
+#define USE_OP(op_type)           \
+  USE_OP_WITHOUT_KERNEL(op_type); \
+  USE_OP_KERNEL(op_type, CPU);    \
+  USE_OP_KERNEL(op_type, GPU)
+#endif

 }  // namespace framework
 }  // namespace paddle
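A note on the machinery above: STATIC_ASSERT_GLOBAL_NAMESPACE works because only at global scope do the qualified and the unqualified names of the freshly declared struct refer to the same type, and the `__op_register_*_handle__()` functions give USE_OP an external symbol to reference so the linker cannot discard the registrar object when operators are linked in from static libraries. A minimal self-contained sketch of the namespace trick (names here are illustrative, not from the commit):

#include <type_traits>

struct __test_global_namespace_demo__ {};
// At global scope the qualified and unqualified names resolve to the
// same struct, so the assertion passes.
static_assert(std::is_same<::__test_global_namespace_demo__,
                           __test_global_namespace_demo__>::value,
              "must be in global namespace");

namespace not_global {
// Expanding the macro here would declare
// not_global::__test_global_namespace_demo__ instead; the two names
// would no longer agree and compilation would stop with the message,
// which is exactly how REGISTER_OP rejects being used inside a namespace.
}  // namespace not_global

int main() { return 0; }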
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
using namespace paddle::framework;
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class CosineOp : public OperatorBase { class CosineOp : public OperatorBase {
@@ -21,13 +19,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
     AddAttr<float>("scale", "scale of cosine op")
         .SetDefault(1.0)
         .LargerThan(0.0);
-    AddType("cos");
     AddComment("This is cos op");
   }
 };
-
-REGISTER_OP(cos_sim, CosineOp, CosineOpProtoAndCheckerMaker);

 class MyTestOp : public OperatorBase {
  public:
   void InferShape(const std::shared_ptr<Scope>& scope) const override {}
@@ -48,15 +43,17 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
     };
     AddAttr<int>("test_attr", "a simple test attribute")
         .AddCustomChecker(my_checker);
-    AddType("my_test_op");
     AddComment("This is my_test op");
   }
 };
-
-REGISTER_OP(my_test_op, MyTestOp, MyTestOpProtoAndCheckerMaker);

 }  // namespace framework
 }  // namespace paddle

+REGISTER_OP(cos_sim, paddle::framework::CosineOp,
+            paddle::framework::CosineOpProtoAndCheckerMaker);
+REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
+            paddle::framework::MyTestOpProtoAndCheckerMaker);
+
 TEST(OpRegistry, CreateOp) {
   paddle::framework::OpDesc op_desc;
   op_desc.set_type("cos_sim");
@@ -71,7 +68,7 @@ TEST(OpRegistry, CreateOp) {
   paddle::framework::OperatorBase* op =
       paddle::framework::OpRegistry::CreateOp(op_desc);

-  auto scope = std::make_shared<Scope>();
+  auto scope = std::make_shared<paddle::framework::Scope>();
   paddle::platform::CPUDeviceContext dev_ctx;
   op->Run(scope, dev_ctx);
   float scale_get = op->GetAttr<float>("scale");
@@ -114,7 +111,7 @@ TEST(OpRegistry, DefaultValue) {
   paddle::framework::OperatorBase* op =
       paddle::framework::OpRegistry::CreateOp(op_desc);

-  auto scope = std::make_shared<Scope>();
+  auto scope = std::make_shared<paddle::framework::Scope>();
   paddle::platform::CPUDeviceContext dev_ctx;
   op->Run(scope, dev_ctx);
   ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
@@ -169,13 +166,8 @@ TEST(OpRegistry, CustomChecker) {
   paddle::framework::OperatorBase* op =
       paddle::framework::OpRegistry::CreateOp(op_desc);
   paddle::platform::CPUDeviceContext dev_ctx;
-  auto scope = std::make_shared<Scope>();
+  auto scope = std::make_shared<paddle::framework::Scope>();
   op->Run(scope, dev_ctx);
   int test_attr = op->GetAttr<int>("test_attr");
   ASSERT_EQ(test_attr, 4);
 }
-
-int main(int argc, char** argv) {
-  testing::InitGoogleTest(&argc, argv);
-  return RUN_ALL_TESTS();
-}
\ No newline at end of file
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <paddle/framework/attr_checker.h>
 #include <paddle/framework/op_desc.pb.h>
 #include <paddle/framework/scope.h>
+#include <paddle/framework/tensor.h>
 #include <paddle/platform/device_context.h>
 #include <paddle/platform/place.h>
 #include <paddle/utils/Error.h>
@@ -49,6 +50,10 @@ class OperatorBase {
   std::string DebugString() const;

+  /// Init will be called after CreateOperator; you can put some
+  /// initialization logic here.
+  virtual void Init() {}
+
   /// InferShape infers the size of Variables used by this Operator with
   /// information inside scope
   virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0;
@@ -99,6 +104,19 @@ class OpKernel {
   virtual ~OpKernel() {}
 };

+template <typename T>
+struct VarToTensor {};
+
+template <>
+struct VarToTensor<Tensor*> {
+  Tensor* operator()(Variable* var) { return var->GetMutable<Tensor>(); }
+};
+
+template <>
+struct VarToTensor<const Tensor*> {
+  const Tensor* operator()(Variable* var) { return &var->Get<Tensor>(); }
+};
+
 class OperatorWithKernel : public OperatorBase {
  public:
   struct OpKernelKey {
@@ -132,19 +150,36 @@ class OperatorWithKernel : public OperatorBase {
   AllOpKernels() {
     static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
     return g_all_op_kernels;
-  };
+  }
+
+  void InferShape(const std::shared_ptr<Scope>& scope) const final {
+    std::vector<const Tensor*> ins;
+    VarNamesToTensors(scope, inputs_, &ins);
+    std::vector<Tensor*> outs;
+    VarNamesToTensors(scope, outputs_, &outs);
+    InferShape(ins, outs);
+  };
+
+ private:
+  template <typename T>
+  void VarNamesToTensors(const std::shared_ptr<Scope>& scope,
+                         const std::vector<std::string>& var_names,
+                         std::vector<T>* container) const {
+    container->reserve(var_names.size());
+    VarToTensor<T> convert;
+    for (auto& name : var_names) {
+      auto var = scope->GetVariable(name);
+      if (var != nullptr) {
+        container->push_back(convert(var));
+      } else {
+        container->push_back(nullptr);
+      }
+    }
+  }
+
+ protected:
+  virtual void InferShape(const std::vector<const Tensor*>& inputs,
+                          const std::vector<Tensor*>& outputs) const = 0;
 };

 }  // namespace framework
 }  // namespace paddle
-
-#define REGISTER_OP_KERNEL(type, PlaceType, KernelType)                   \
-  struct __op_kernel_register__##type##__ {                               \
-    __op_kernel_register__##type##__() {                                  \
-      ::paddle::framework::OperatorWithKernel::OpKernelKey key;           \
-      key.place_ = PlaceType();                                           \
-      ::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
-          .reset(new KernelType());                                       \
-    }                                                                     \
-  };                                                                      \
-  static __op_kernel_register__##type##__ __reg_kernel_##type##__
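The macro removal above pairs with the new InferShape split: OperatorWithKernel now implements the scope-level InferShape as a final method that gathers Tensor pointers via VarNamesToTensors and forwards them to a protected, tensor-level pure virtual (a template-method design). A rough sketch of what a subclass now overrides, assuming the commit's headers; the op name here is hypothetical, and the real AddOp later in this diff does the same thing:

#include <vector>

#include <paddle/framework/operator.h>

// Hypothetical operator: with the split above, a subclass overrides only
// the tensor-level hook; the scope-level InferShape is final and does the
// Variable -> Tensor plumbing before delegating here.
class ScaleOp : public paddle::framework::OperatorWithKernel {
 protected:
  void InferShape(
      const std::vector<const paddle::framework::Tensor*>& inputs,
      const std::vector<paddle::framework::Tensor*>& outputs) const override {
    // Shape validation against the input tensors would go here,
    // mirroring AddOp::InferShape below.
  }
};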
@@ -21,54 +21,21 @@ namespace framework {
 class OperatorTest : public OperatorBase {
  public:
+  void Init() override { x = 1; }
   void InferShape(const std::shared_ptr<Scope>& scope) const override {}
   void Run(const std::shared_ptr<Scope>& scope,
            const platform::DeviceContext& dev_ctx) const override {
     float scale = GetAttr<float>("scale");
     ASSERT_NEAR(scale, 3.14, 1e-5);
     ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
+    ASSERT_EQ(x, 1);
     ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
   }
-};
-
-class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
  public:
-  OperatorTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("input", "input of test op");
-    AddOutput("output", "output of test op");
-    AddAttr<float>("scale", "scale of cosine op")
-        .SetDefault(1.0)
-        .LargerThan(0.0);
-    AddType("test_operator");
-    AddComment("This is test op");
-  }
+  float x = 0;
 };
-
-REGISTER_OP(test_operator, OperatorTest, OperatorTestProtoAndCheckerMaker);
-
-TEST(OperatorBase, all) {
-  OpDesc op_desc;
-  op_desc.set_type("test_operator");
-  *op_desc.mutable_inputs()->Add() = "IN1";
-  *op_desc.mutable_outputs()->Add() = "OUT1";
-  auto attr = op_desc.mutable_attrs()->Add();
-  attr->set_name("scale");
-  attr->set_type(paddle::framework::AttrType::FLOAT);
-  float scale = 3.14;
-  attr->set_f(scale);
-  platform::CPUDeviceContext device_context;
-  auto scope = std::make_shared<Scope>();
-  OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
-  ASSERT_EQ(op->GetAttr<float>("scale"), scale);
-  scope->CreateVariable("OUT1");
-  op->Run(scope, device_context);
-  std::cout << op->DebugString() << std::endl;
-  delete op;
-}

 class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
  public:
   OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
@@ -78,14 +45,14 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
     AddAttr<float>("scale", "scale of cosine op")
         .SetDefault(1.0)
         .LargerThan(0.0);
-    AddType("test_operator");
     AddComment("This is test op");
   }
 };

 class OpWithKernelTest : public OperatorWithKernel {
- public:
-  void InferShape(const std::shared_ptr<Scope>& scope) const override {}
+ protected:
+  void InferShape(const std::vector<const Tensor*>& inputs,
+                  const std::vector<Tensor*>& outputs) const override {}
 };

 class CPUKernelTest : public OpKernel {
@@ -98,10 +65,16 @@ class CPUKernelTest : public OpKernel {
   }
 };

-REGISTER_OP(op_with_kernel, OpWithKernelTest, OpKernelTestProtoAndCheckerMaker);
-REGISTER_OP_KERNEL(op_with_kernel, platform::CPUPlace, CPUKernelTest);
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
+            paddle::framework::OpKernelTestProtoAndCheckerMaker);
+REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest);

 TEST(OpKernel, all) {
+  using namespace paddle::framework;
   OpDesc op_desc;
   op_desc.set_type("op_with_kernel");
   *op_desc.mutable_inputs()->Add() = "IN1";
@@ -111,7 +84,7 @@ TEST(OpKernel, all) {
   attr->set_type(paddle::framework::AttrType::FLOAT);
   attr->set_f(3.14);

-  platform::CPUDeviceContext cpu_device_context;
+  paddle::platform::CPUDeviceContext cpu_device_context;
   auto scope = std::make_shared<Scope>();

   OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
@@ -119,5 +92,3 @@
   delete op;
 }
-
-}  // namespace framework
-}  // namespace paddle
\ No newline at end of file
+if(WITH_GPU)
+    nv_library(add_op SRCS add_op.cc add_op.cu DEPS operator op_registry glog ddim)
+else()
+    cc_library(add_op SRCS add_op.cc DEPS operator op_registry glog ddim)
+endif()
+
+cc_test(add_op_test SRCS add_op_test.cc DEPS add_op)
+#include <paddle/framework/op_registry.h>
+#include <paddle/framework/tensor.h>
+#include <paddle/operators/add_op.h>
+
+namespace paddle {
+namespace operators {
+
+class AddOp : public framework::OperatorWithKernel {
+ protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {
+    PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two");
+    PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one");
+    PADDLE_ENFORCE(
+        inputs[0] != nullptr && inputs[1] != nullptr && outputs[0] != nullptr,
+        "Inputs/Outputs of AddOp must all be set");
+    PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
+                   "Two inputs of AddOp must have the same dimension.");
+    // Need set dims in Tensor
+    // outputs[0]->set_dims(inputs[0]->dims())
+  }
+};
+
+class AddOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The first input of add op");
+    AddInput("Y", "The second input of add op");
+    AddOutput("Out", "The output of add op");
+    AddComment(R"DOC(
+Two Element Add Operator.
+
+The equation is: Out = X + Y
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    add_two, ::paddle::operators::AddKernel<::paddle::platform::CPUPlace>);
\ No newline at end of file
+#include <paddle/operators/add_op.h>
+
+#include <paddle/framework/op_registry.h>
+
+REGISTER_OP_GPU_KERNEL(add_two,
+                       paddle::operators::AddKernel<paddle::platform::GPUPlace>);
\ No newline at end of file
+#pragma once
+
+#include <glog/logging.h>
+#include <paddle/framework/operator.h>
+
+namespace paddle {
+namespace operators {
+
+template <typename Place>
+class AddKernel : public framework::OpKernel {
+ public:
+  void Compute(const KernelContext &context) const override {
+    LOG(INFO) << "Add kernel in " << typeid(Place).name();
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+#include <gtest/gtest.h>
+
+#define private public
+#include <paddle/framework/op_registry.h>
+
+USE_OP(add_two);
+
+TEST(AddOp, GetOpProto) {
+  auto& protos = paddle::framework::OpRegistry::protos();
+  auto it = protos.find("add_two");
+  ASSERT_NE(it, protos.end());
+}
\ No newline at end of file
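Putting the new pieces together, here is a minimal driver sketch (not part of the commit; it assumes the headers above, linking against the add_op library, and either a CPU-only build or the GPU kernel also linked in) that creates and runs the registered add_two operator, mirroring TEST(OpKernel, all):

#include <memory>

#include <paddle/framework/op_registry.h>
#include <paddle/framework/scope.h>
#include <paddle/platform/device_context.h>

USE_OP(add_two);  // reference the registrar so the linker keeps add_op.cc

int main() {
  // Describe an "add_two" op over hypothetical variable names.
  paddle::framework::OpDesc op_desc;
  op_desc.set_type("add_two");
  *op_desc.mutable_inputs()->Add() = "X";
  *op_desc.mutable_inputs()->Add() = "Y";
  *op_desc.mutable_outputs()->Add() = "Out";

  paddle::framework::OperatorBase* op =
      paddle::framework::OpRegistry::CreateOp(op_desc);

  auto scope = std::make_shared<paddle::framework::Scope>();
  paddle::platform::CPUDeviceContext dev_ctx;
  op->Run(scope, dev_ctx);  // AddKernel logs "Add kernel in ..." via glog
  delete op;
  return 0;
}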
@@ -67,20 +67,20 @@ extern void *cublas_dso_handle;
   __macro(cublasSgemm);             \
   __macro(cublasDgemm);             \
   __macro(cublasSgeam);             \
-  __macro(cublasDgeam);
-
-DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate);
-DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy);
-DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream);
-DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode);
-DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode);
-DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched);
-DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched);
-DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched);
-DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched);
-DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched);
-DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched);
-DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetrfBatched);
+  __macro(cublasDgeam);             \
+  __macro(cublasCreate_v2);         \
+  __macro(cublasDestroy_v2);        \
+  __macro(cublasSetStream_v2);      \
+  __macro(cublasSetPointerMode_v2); \
+  __macro(cublasGetPointerMode_v2); \
+  __macro(cublasSgemmBatched);      \
+  __macro(cublasDgemmBatched);      \
+  __macro(cublasCgemmBatched);      \
+  __macro(cublasZgemmBatched);      \
+  __macro(cublasSgetrfBatched);     \
+  __macro(cublasSgetriBatched);     \
+  __macro(cublasDgetrfBatched);     \
+  __macro(cublasDgetriBatched)

 CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP);
......
"""
Default scope function.
`Paddle` manages Scope as programming language's scope. It just a
thread-local stack of Scope. Top of that stack is current scope, the bottom
of that stack is all scopes' parent.
Invoking `create_var/get_var` can `create/get` variable in current scope.
Invoking `enter_local_scope/leave_local_scope` can create or destroy local
scope.
A `scoped_function` will take a `function` as input. That function will be
invoked in a new local scope.
"""
import paddle.v2.framework.core
import threading
__tl_scope__ = threading.local()
__all__ = [
'get_cur_scope', 'enter_local_scope', 'leave_local_scope', 'create_var',
'get_var', 'scoped_function'
]
def get_cur_scope():
"""
Get current scope.
:rtype: paddle.v2.framework.core.Scope
"""
cur_scope_stack = getattr(__tl_scope__, 'cur_scope', None)
if cur_scope_stack is None:
__tl_scope__.cur_scope = list()
if len(__tl_scope__.cur_scope) == 0:
__tl_scope__.cur_scope.append(paddle.v2.framework.core.Scope(None))
return __tl_scope__.cur_scope[-1]
def enter_local_scope():
"""
Enter a new local scope
"""
cur_scope = get_cur_scope()
new_scope = paddle.v2.framework.core.Scope(cur_scope)
__tl_scope__.cur_scope.append(new_scope)
def leave_local_scope():
"""
Leave local scope
"""
__tl_scope__.cur_scope.pop()
def create_var(name):
"""
create variable in current scope.
"""
return get_cur_scope().create_var(name)
def get_var(name):
"""
get variable in current scope.
"""
return get_cur_scope().get_var(name)
def scoped_function(func):
"""
invoke `func` in new scope.
:param func: a callable function that will be run in new scope.
:type func: callable
"""
enter_local_scope()
try:
func()
except:
raise
finally:
leave_local_scope()
-add_python_test(test_framework test_protobuf.py test_scope.py)
+add_python_test(test_framework test_protobuf.py test_scope.py
+    test_default_scope_funcs.py)
+from paddle.v2.framework.default_scope_funcs import *
+import unittest
+
+
+class TestDefaultScopeFuncs(unittest.TestCase):
+    def test_cur_scope(self):
+        self.assertIsNotNone(get_cur_scope())
+
+    def test_none_variable(self):
+        self.assertIsNone(get_var("test"))
+
+    def test_create_var_get_var(self):
+        var_a = create_var("var_a")
+        self.assertIsNotNone(var_a)
+        self.assertIsNotNone(get_cur_scope().get_var('var_a'))
+        enter_local_scope()
+        self.assertIsNotNone(get_cur_scope().get_var('var_a'))
+        leave_local_scope()
+
+    def test_var_get_int(self):
+        def __new_scope__():
+            i = create_var("var_i")
+            self.assertFalse(i.is_int())
+            i.set_int(10)
+            self.assertTrue(i.is_int())
+            self.assertEqual(10, i.get_int())
+
+        for _ in xrange(10):
+            scoped_function(__new_scope__)
+
+
+if __name__ == '__main__':
+    unittest.main()