提交 0b0b3ba1 编写于 作者: Q qijun

Merge remote-tracking branch 'baidu/develop' into tensor_to_EigenTensor

......@@ -28,7 +28,9 @@ if(NOT CMAKE_CROSSCOMPILING)
endif(NOT CMAKE_CROSSCOMPILING)
find_package(Git REQUIRED)
find_package(Threads REQUIRED)
find_package(Boost QUIET)
if(NOT ANDROID)
find_package(Boost QUIET)
endif()
include(simd)
......@@ -151,7 +153,9 @@ if(WITH_GOLANG)
endif(WITH_GOLANG)
add_subdirectory(paddle)
add_subdirectory(python)
if(WITH_PYTHON)
add_subdirectory(python)
endif()
if(WITH_DOC)
add_subdirectory(doc)
endif()
......@@ -106,6 +106,9 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
SET(CMAKE_SYSTEM_PROCESSOR armv7-a)
ENDIF()
ENDIF()
IF(ANDROID_ABI STREQUAL "arm64-v8a")
SET(ANDROID_TOOLCHAIN_NAME aarch64-linux-android)
ENDIF()
SET(ANDROID_TOOLCHAIN_PREFIX "${ANDROID_TOOLCHAIN_ROOT}/bin/${ANDROID_TOOLCHAIN_NAME}-")
ENDIF()
......@@ -162,6 +165,10 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
ENDIF()
ENDIF()
IF(ANDROID_ABI STREQUAL "arm64-v8a")
LIST(APPEND ANDROID_COMPILER_FLAGS -march=armv8-a)
ENDIF()
STRING(REPLACE ";" " " ANDROID_COMPILER_FLAGS "${ANDROID_COMPILER_FLAGS}")
STRING(REPLACE ";" " " ANDROID_LINKER_FLAGS "${ANDROID_LINKER_FLAGS}")
......
......@@ -52,6 +52,7 @@ ExternalProject_Add(
ADD_LIBRARY(glog STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES})
ADD_DEPENDENCIES(glog extern_glog)
ADD_DEPENDENCIES(glog extern_glog gflags)
LINK_LIBRARIES(glog gflags)
LIST(APPEND external_project_dependencies glog)
......@@ -32,7 +32,12 @@ IF(NOT ${CBLAS_FOUND})
# arm_soft_fp_abi branch of OpenBLAS to support softfp
# https://github.com/xianyi/OpenBLAS/tree/arm_soft_fp_abi
SET(OPENBLAS_COMMIT "b5c96fcfcdc82945502a2303116a64d89985daf5")
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 ARM_SOFTFP_ABI=1 USE_THREAD=0)
IF(ANDROID_ABI MATCHES "^armeabi(-v7a)?$")
SET(TARGET "ARMV7")
ELSEIF(ANDROID_ABI STREQUAL "arm64-v8a")
SET(TARGET "ARMV8")
ENDIF()
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=${TARGET} ARM_SOFTFP_ABI=1 USE_THREAD=0)
ELSEIF(RPI)
# use hardfp
SET(OPENBLAS_COMMIT "v0.2.19")
......
......@@ -90,11 +90,11 @@
# including binary directory for generated headers.
include_directories(${CMAKE_CURRENT_BINARY_DIR})
if(NOT APPLE)
if(NOT APPLE AND NOT ANDROID)
find_package(Threads REQUIRED)
link_libraries(${CMAKE_THREAD_LIBS_INIT})
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl -lrt")
endif(NOT APPLE)
endif(NOT APPLE AND NOT ANDROID)
function(merge_static_libs TARGET_NAME)
set(libs ${ARGN})
......
......@@ -34,7 +34,7 @@ func main() {
var idx int
var cp *pserver.Checkpoint
var cp pserver.Checkpoint
var e *pserver.EtcdClient
if *index >= 0 {
idx = *index
......
......@@ -83,7 +83,7 @@ type parameterCheckpoint struct {
}
// NewCheckpointFromFile loads parameters and state from checkpoint file
func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint, error) {
func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, error) {
v, err := e.GetKey(PsPath+string(idx), 3*time.Second)
if err != nil {
return nil, err
......@@ -110,7 +110,7 @@ func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint,
}
dec := gob.NewDecoder(bytes.NewReader(content))
cp := &Checkpoint{}
cp := Checkpoint{}
if err = dec.Decode(cp); err != nil {
return nil, err
}
......@@ -119,7 +119,7 @@ func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint,
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp *Checkpoint) (*Service, error) {
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
s := &Service{
idx: idx,
checkpointInterval: interval,
......@@ -130,7 +130,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
s.initialized = make(chan struct{})
if cp != nil {
for _, item := range *cp {
for _, item := range cp {
p := ParameterWithConfig{
Param: item.Param,
Config: item.Config,
......
......@@ -14,6 +14,7 @@ if(Boost_FOUND)
add_subdirectory(memory)
add_subdirectory(platform)
add_subdirectory(framework)
add_subdirectory(operators)
add_subdirectory(pybind)
endif()
......
......@@ -11,8 +11,8 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(operator SRCS operator.cc DEPS op_desc protobuf)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry place)
cc_library(operator SRCS operator.cc DEPS op_desc device_context)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
......
#pragma once
#include <algorithm>
#include <type_traits>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
......@@ -81,8 +82,6 @@ class OpProtoAndCheckerMaker {
return op_checker_->AddAttrChecker<T>(name);
}
void AddType(const std::string& op_type) { proto_->set_type(op_type); }
void AddComment(const std::string& comment) {
*(proto_->mutable_comment()) = comment;
}
......@@ -101,8 +100,11 @@ class OpRegistry {
OpProto& op_proto = protos()[op_type];
OpAttrChecker& op_checker = op_checkers()[op_type];
ProtoMakerType(&op_proto, &op_checker);
PADDLE_ENFORCE(op_proto.IsInitialized(),
"Fail to initialize %s's OpProto !", op_type);
*op_proto.mutable_type() = op_type;
PADDLE_ENFORCE(
op_proto.IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_proto.InitializationErrorString());
}
static OperatorBase* CreateOp(const OpDesc& op_desc) {
......@@ -119,6 +121,7 @@ class OpRegistry {
op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
}
op_checkers().at(op_type).Check(op->attrs_);
op->Init();
return op;
}
......@@ -142,18 +145,73 @@ class OpRegistry {
template <typename OpType, typename ProtoMakerType>
class OpRegisterHelper {
public:
OpRegisterHelper(std::string op_type) {
OpRegisterHelper(const char* op_type) {
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
}
};
#define REGISTER_OP(type, op_class, op_maker_class) \
class op_class##Register { \
private: \
const static OpRegisterHelper<op_class, op_maker_class> reg; \
}; \
const OpRegisterHelper<op_class, op_maker_class> op_class##Register::reg( \
#type)
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
#define REGISTER_OP(__op_type, __op_class, __op_maker_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_op__##__op_type, \
"REGISTER_OP must be in global namespace"); \
static ::paddle::framework::OpRegisterHelper<__op_class, __op_maker_class> \
__op_register_##__op_type##__(#__op_type); \
int __op_register_##__op_type##_handle__() { return 0; }
#define REGISTER_OP_KERNEL(type, GPU_OR_CPU, PlaceType, KernelType) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##type##_##GPU_OR_CPU##__, \
"REGISTER_OP_KERNEL must be in global namespace"); \
struct __op_kernel_register__##type##__ { \
__op_kernel_register__##type##__() { \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new KernelType()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__; \
int __op_kernel_register_##type##_handle_##GPU_OR_CPU##__() { return 0; }
#define REGISTER_OP_GPU_KERNEL(type, KernelType) \
REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, KernelType)
#define REGISTER_OP_CPU_KERNEL(type, KernelType) \
REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, KernelType)
#define USE_OP_WITHOUT_KERNEL(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_without_kernel_##op_type, \
"USE_OP_WITHOUT_KERNEL must be in global namespace"); \
extern int __op_register_##op_type##_handle__(); \
static int __use_op_ptr_##op_type##_without_kernel__ \
__attribute__((unused)) = __op_register_##op_type##_handle__()
#define USE_OP_KERNEL(op_type, DEVICE_TYPE) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
"USE_OP_KERNEL must be in global namespace"); \
extern int __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__(); \
static int __use_op_ptr_##op_type##_##DEVICE_TYPE##_kernel__ \
__attribute__((unused)) = \
__op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__()
#ifdef PADDLE_ONLY_CPU
#define USE_OP(op_type) \
USE_OP_WITHOUT_KERNEL(op_type); \
USE_OP_KERNEL(op_type, CPU);
#else
#define USE_OP(op_type) \
USE_OP_WITHOUT_KERNEL(op_type); \
USE_OP_KERNEL(op_type, CPU); \
USE_OP_KERNEL(op_type, GPU)
#endif
} // namespace framework
} // namespace paddle
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
using namespace paddle::framework;
namespace paddle {
namespace framework {
class CosineOp : public OperatorBase {
......@@ -21,13 +19,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("cos");
AddComment("This is cos op");
}
};
REGISTER_OP(cos_sim, CosineOp, CosineOpProtoAndCheckerMaker);
class MyTestOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
......@@ -48,15 +43,17 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
};
AddAttr<int>("test_attr", "a simple test attribute")
.AddCustomChecker(my_checker);
AddType("my_test_op");
AddComment("This is my_test op");
}
};
REGISTER_OP(my_test_op, MyTestOp, MyTestOpProtoAndCheckerMaker);
} // namespace framework
} // namespace paddle
REGISTER_OP(cos_sim, paddle::framework::CosineOp,
paddle::framework::CosineOpProtoAndCheckerMaker);
REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
paddle::framework::MyTestOpProtoAndCheckerMaker);
TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
......@@ -71,7 +68,7 @@ TEST(OpRegistry, CreateOp) {
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>();
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
float scale_get = op->GetAttr<float>("scale");
......@@ -114,7 +111,7 @@ TEST(OpRegistry, DefaultValue) {
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>();
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
......@@ -169,13 +166,8 @@ TEST(OpRegistry, CustomChecker) {
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<Scope>();
auto scope = std::make_shared<paddle::framework::Scope>();
op->Run(scope, dev_ctx);
int test_attr = op->GetAttr<int>("test_attr");
ASSERT_EQ(test_attr, 4);
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
\ No newline at end of file
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <paddle/framework/attr_checker.h>
#include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/scope.h>
#include <paddle/framework/tensor.h>
#include <paddle/platform/device_context.h>
#include <paddle/platform/place.h>
#include <paddle/utils/Error.h>
......@@ -49,6 +50,10 @@ class OperatorBase {
std::string DebugString() const;
/// Init will be called after CreateOperator, you can put some initialization
/// logic here.
virtual void Init() {}
/// InferShape infer the size of Variables used by this Operator with
/// information inside scope
virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0;
......@@ -99,6 +104,19 @@ class OpKernel {
virtual ~OpKernel() {}
};
template <typename T>
struct VarToTensor {};
template <>
struct VarToTensor<Tensor*> {
Tensor* operator()(Variable* var) { return var->GetMutable<Tensor>(); }
};
template <>
struct VarToTensor<const Tensor*> {
const Tensor* operator()(Variable* var) { return &var->Get<Tensor>(); }
};
class OperatorWithKernel : public OperatorBase {
public:
struct OpKernelKey {
......@@ -132,19 +150,36 @@ class OperatorWithKernel : public OperatorBase {
AllOpKernels() {
static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
return g_all_op_kernels;
}
void InferShape(const std::shared_ptr<Scope>& scope) const final {
std::vector<const Tensor*> ins;
VarNamesToTensors(scope, inputs_, &ins);
std::vector<Tensor*> outs;
VarNamesToTensors(scope, outputs_, &outs);
InferShape(ins, outs);
};
private:
template <typename T>
void VarNamesToTensors(const std::shared_ptr<Scope>& scope,
const std::vector<std::string>& var_names,
std::vector<T>* container) const {
container->reserve(var_names.size());
VarToTensor<T> convert;
for (auto& name : var_names) {
auto var = scope->GetVariable(name);
if (var != nullptr) {
container->push_back(convert(var));
} else {
container->push_back(nullptr);
}
}
}
protected:
virtual void InferShape(const std::vector<const Tensor*>& inputs,
const std::vector<Tensor*>& outputs) const = 0;
};
} // namespace framework
} // namespace paddle
#define REGISTER_OP_KERNEL(type, PlaceType, KernelType) \
struct __op_kernel_register__##type##__ { \
__op_kernel_register__##type##__() { \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new KernelType()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__
......@@ -21,54 +21,21 @@ namespace framework {
class OperatorTest : public OperatorBase {
public:
void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
float scale = GetAttr<float>("scale");
ASSERT_NEAR(scale, 3.14, 1e-5);
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
ASSERT_EQ(x, 1);
ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
}
};
class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OperatorTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
AddOutput("output", "output of test op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("test_operator");
AddComment("This is test op");
}
float x = 0;
};
REGISTER_OP(test_operator, OperatorTest, OperatorTestProtoAndCheckerMaker);
TEST(OperatorBase, all) {
OpDesc op_desc;
op_desc.set_type("test_operator");
*op_desc.mutable_inputs()->Add() = "IN1";
*op_desc.mutable_outputs()->Add() = "OUT1";
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
float scale = 3.14;
attr->set_f(scale);
platform::CPUDeviceContext device_context;
auto scope = std::make_shared<Scope>();
OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
ASSERT_EQ(op->GetAttr<float>("scale"), scale);
scope->CreateVariable("OUT1");
op->Run(scope, device_context);
std::cout << op->DebugString() << std::endl;
delete op;
}
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
......@@ -78,14 +45,14 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("test_operator");
AddComment("This is test op");
}
};
class OpWithKernelTest : public OperatorWithKernel {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
protected:
void InferShape(const std::vector<const Tensor*>& inputs,
const std::vector<Tensor*>& outputs) const override {}
};
class CPUKernelTest : public OpKernel {
......@@ -98,10 +65,16 @@ class CPUKernelTest : public OpKernel {
}
};
REGISTER_OP(op_with_kernel, OpWithKernelTest, OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_KERNEL(op_with_kernel, platform::CPUPlace, CPUKernelTest);
} // namespace framework
} // namespace paddle
REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest);
TEST(OpKernel, all) {
using namespace paddle::framework;
OpDesc op_desc;
op_desc.set_type("op_with_kernel");
*op_desc.mutable_inputs()->Add() = "IN1";
......@@ -111,7 +84,7 @@ TEST(OpKernel, all) {
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(3.14);
platform::CPUDeviceContext cpu_device_context;
paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>();
OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
......@@ -119,5 +92,3 @@ TEST(OpKernel, all) {
delete op;
}
} // namespace framework
} // namespace paddle
\ No newline at end of file
if(WITH_GPU)
nv_library(add_op SRCS add_op.cc add_op.cu DEPS operator op_registry glog ddim)
else()
cc_library(add_op SRCS add_op.cc DEPS operator op_registry glog ddim)
endif()
cc_test(add_op_test SRCS add_op_test.cc DEPS add_op)
#include <paddle/framework/op_registry.h>
#include <paddle/framework/tensor.h>
#include <paddle/operators/add_op.h>
namespace paddle {
namespace operators {
class AddOp : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE(
inputs[0] != nullptr && inputs[1] != nullptr && outputs[0] != nullptr,
"Inputs/Outputs of AddOp must all be set");
PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
"Two input of Add Op's dimension must be same.");
// Need set dims in Tensor
// outputs[0]->set_dims(inputs[0]->dims())
}
};
class AddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of add op");
AddInput("Y", "The second input of add op");
AddOutput("Out", "The output of add op");
AddComment(R"DOC(
Two Element Add Operator.
The equation is: Out = X + Y
)DOC");
}
};
} // namespace op
} // namespace paddle
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_OP_CPU_KERNEL(
add_two, ::paddle::operators::AddKernel<::paddle::platform::CPUPlace>);
\ No newline at end of file
#include <paddle/operators/add_op.h>
#include <paddle/framework/op_registry.h>
REGISTER_OP_GPU_KERNEL(add_two,
paddle::operators::AddKernel<paddle::platform::GPUPlace>);
\ No newline at end of file
#pragma once
#include <glog/logging.h>
#include <paddle/framework/operator.h>
namespace paddle {
namespace operators {
template <typename Place>
class AddKernel : public framework::OpKernel {
public:
void Compute(const KernelContext &context) const override {
LOG(INFO) << "Add kernel in " << typeid(Place).name();
}
};
} // namespace op
} // namespace paddle
#include <gtest/gtest.h>
#define private public
#include <paddle/framework/op_registry.h>
USE_OP(add_two);
TEST(AddOp, GetOpProto) {
auto& protos = paddle::framework::OpRegistry::protos();
auto it = protos.find("add_two");
ASSERT_NE(it, protos.end());
}
\ No newline at end of file
......@@ -67,20 +67,20 @@ extern void *cublas_dso_handle;
__macro(cublasSgemm); \
__macro(cublasDgemm); \
__macro(cublasSgeam); \
__macro(cublasDgeam);
DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate);
DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy);
DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream);
DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode);
DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode);
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched);
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched);
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched);
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched);
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched);
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched);
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetrfBatched);
__macro(cublasDgeam); \
__macro(cublasCreate_v2); \
__macro(cublasDestroy_v2); \
__macro(cublasSetStream_v2); \
__macro(cublasSetPointerMode_v2); \
__macro(cublasGetPointerMode_v2); \
__macro(cublasSgemmBatched); \
__macro(cublasDgemmBatched); \
__macro(cublasCgemmBatched); \
__macro(cublasZgemmBatched); \
__macro(cublasSgetrfBatched); \
__macro(cublasSgetriBatched); \
__macro(cublasDgetrfBatched); \
__macro(cublasDgetriBatched)
CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP);
......
"""
Default scope function.
`Paddle` manages Scope as programming language's scope. It just a
thread-local stack of Scope. Top of that stack is current scope, the bottom
of that stack is all scopes' parent.
Invoking `create_var/get_var` can `create/get` variable in current scope.
Invoking `enter_local_scope/leave_local_scope` can create or destroy local
scope.
A `scoped_function` will take a `function` as input. That function will be
invoked in a new local scope.
"""
import paddle.v2.framework.core
import threading
__tl_scope__ = threading.local()
__all__ = [
'get_cur_scope', 'enter_local_scope', 'leave_local_scope', 'create_var',
'get_var', 'scoped_function'
]
def get_cur_scope():
"""
Get current scope.
:rtype: paddle.v2.framework.core.Scope
"""
cur_scope_stack = getattr(__tl_scope__, 'cur_scope', None)
if cur_scope_stack is None:
__tl_scope__.cur_scope = list()
if len(__tl_scope__.cur_scope) == 0:
__tl_scope__.cur_scope.append(paddle.v2.framework.core.Scope(None))
return __tl_scope__.cur_scope[-1]
def enter_local_scope():
"""
Enter a new local scope
"""
cur_scope = get_cur_scope()
new_scope = paddle.v2.framework.core.Scope(cur_scope)
__tl_scope__.cur_scope.append(new_scope)
def leave_local_scope():
"""
Leave local scope
"""
__tl_scope__.cur_scope.pop()
def create_var(name):
"""
create variable in current scope.
"""
return get_cur_scope().create_var(name)
def get_var(name):
"""
get variable in current scope.
"""
return get_cur_scope().get_var(name)
def scoped_function(func):
"""
invoke `func` in new scope.
:param func: a callable function that will be run in new scope.
:type func: callable
"""
enter_local_scope()
try:
func()
except:
raise
finally:
leave_local_scope()
add_python_test(test_framework test_protobuf.py test_scope.py)
add_python_test(test_framework test_protobuf.py test_scope.py
test_default_scope_funcs.py)
from paddle.v2.framework.default_scope_funcs import *
import unittest
class TestDefaultScopeFuncs(unittest.TestCase):
def test_cur_scope(self):
self.assertIsNotNone(get_cur_scope())
def test_none_variable(self):
self.assertIsNone(get_var("test"))
def test_create_var_get_var(self):
var_a = create_var("var_a")
self.assertIsNotNone(var_a)
self.assertIsNotNone(get_cur_scope().get_var('var_a'))
enter_local_scope()
self.assertIsNotNone(get_cur_scope().get_var('var_a'))
leave_local_scope()
def test_var_get_int(self):
def __new_scope__():
i = create_var("var_i")
self.assertFalse(i.is_int())
i.set_int(10)
self.assertTrue(i.is_int())
self.assertEqual(10, i.get_int())
for _ in xrange(10):
scoped_function(__new_scope__)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册