diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 6e648cb53a2b7056d6acf010b6c7f251ac023459..2c1335d2054ce43b2a2b70ec09db4e127328badf 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -93,7 +93,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}) if(NOT APPLE AND NOT ANDROID) find_package(Threads REQUIRED) link_libraries(${CMAKE_THREAD_LIBS_INIT}) - set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl") + set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl -lrt") endif(NOT APPLE) function(merge_static_libs TARGET_NAME) diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index 2c1eb7521d896e6d67021e4d020f274bb32d123c..58a35564f83928ee0bdaad63b154ed57d8d8a735 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -15,7 +15,6 @@ if(Boost_FOUND) add_subdirectory(memory) add_subdirectory(platform) add_subdirectory(framework) - add_subdirectory(operators) add_subdirectory(pybind) endif() diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index aac49fdb7a04ac566ad24c6d17f9af991241e45b..b8642ca22ab340cade5ded62b6e1b5d38680869d 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -12,7 +12,7 @@ cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) proto_library(op_desc SRCS op_desc.proto DEPS attr_type) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_library(operator SRCS operator.cc DEPS op_desc protobuf) -cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) +cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry place) cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 02c99d50bb50cbd49a56a2282e55c148d4e6af16..248c7a1a3b866ae3bf2af33d0ff67b92d0f9c456 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -147,13 +147,13 @@ class OpRegisterHelper { } }; -#define REGISTER_OP(__op_class, __op_maker_class, __op_type) \ - class __op_class##Register { \ - private: \ - const static OpRegisterHelper<__op_class, __op_maker_class> reg; \ - }; \ - const OpRegisterHelper<__op_class, __op_maker_class> \ - __op_class##Register::reg(#__op_type); +#define REGISTER_OP(type, op_class, op_maker_class) \ + class op_class##Register { \ + private: \ + const static OpRegisterHelper reg; \ + }; \ + const OpRegisterHelper op_class##Register::reg( \ + #type) } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index c4baafc2aebc8d009a388635bbab180d86a4b914..f5162fb870a91e566a0d2b1419050fe0799b199b 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -1,17 +1,15 @@ #include "paddle/framework/op_registry.h" #include -#include "paddle/framework/operator.h" -#include "paddle/operators/demo_op.h" using namespace paddle::framework; namespace paddle { namespace framework { -class CosineOp : public OperatorWithKernel { +class CosineOp : public OperatorBase { public: - void Run(const OpRunContext* context) const override { - printf("%s\n", DebugString().c_str()); - } + void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const override {} + void InferShape(const std::shared_ptr& scope) const override {} }; class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { @@ -28,14 +26,15 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { } }; -REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) +REGISTER_OP(cos_sim, CosineOp, CosineOpProtoAndCheckerMaker); + +class MyTestOp : public OperatorBase { + public: + void InferShape(const std::shared_ptr& scope) const override {} + void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const override {} -class MyTestOp : public OperatorWithKernel { public: - void Run(const OpRunContext* ctx) const override { - printf("%s\n", DebugString().c_str()); - printf("test_attr = %d\n", ctx->op_->GetAttr("test_attr")); - } }; class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { @@ -54,7 +53,7 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { } }; -REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op) +REGISTER_OP(my_test_op, MyTestOp, MyTestOpProtoAndCheckerMaker); } // namespace framework } // namespace paddle @@ -73,8 +72,8 @@ TEST(OpRegistry, CreateOp) { paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); auto scope = std::make_shared(); - auto dev_ctx = DeviceContext(); - op->Run(scope, &dev_ctx); + paddle::platform::CPUDeviceContext dev_ctx; + op->Run(scope, dev_ctx); float scale_get = op->GetAttr("scale"); ASSERT_EQ(scale_get, scale); } @@ -116,8 +115,8 @@ TEST(OpRegistry, DefaultValue) { paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); auto scope = std::make_shared(); - auto dev_ctx = DeviceContext(); - op->Run(scope, &dev_ctx); + paddle::platform::CPUDeviceContext dev_ctx; + op->Run(scope, dev_ctx); ASSERT_EQ(op->GetAttr("scale"), 1.0); } @@ -169,9 +168,9 @@ TEST(OpRegistry, CustomChecker) { attr->set_i(4); paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - auto dev_ctx = DeviceContext(); + paddle::platform::CPUDeviceContext dev_ctx; auto scope = std::make_shared(); - op->Run(scope, &dev_ctx); + op->Run(scope, dev_ctx); int test_attr = op->GetAttr("test_attr"); ASSERT_EQ(test_attr, 4); } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 3db3706e47dfab49f54b3f1f9f2e41c53fc3f298..8f7adff8b3982e91a3d7f6d598cd62d5005d5f17 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -39,13 +39,5 @@ std::string OperatorBase::DebugString() const { return ss.str(); } -const Variable* OpRunContext::Input(int index) const { - return scope_->GetVariable(op_->inputs_[index]); -} - -Variable* OpRunContext::Output(int index) const { - return scope_->GetVariable(op_->outputs_[index]); -} - } // namespace framework } // namespace paddle \ No newline at end of file diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 6570d5869814a198195968606d041055e847ca08..0ce422e007c39ce0c3f5f7a89650cc211919ea8f 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -14,44 +14,22 @@ limitations under the License. */ #pragma once +#include +#include +#include +#include +#include +#include #include #include #include #include -#include "paddle/framework/attr_checker.h" -#include "paddle/framework/op_desc.pb.h" -#include "paddle/framework/scope.h" -#include "paddle/utils/Error.h" - namespace paddle { namespace framework { class OperatorBase; -class DeviceContext {}; - -/** - * OpRunContext is the only parameter of Operator's Run function. - * Run will get input/output variables, state such as momentum and - * device resource such as CUDA stream, cublas handle, etc. from - * OpRunContext. User should construct it before run the Operator. - */ -class OpRunContext { - public: - OpRunContext(const OperatorBase* op, const std::shared_ptr scope, - const DeviceContext* device_context) - : op_(op), scope_(scope), device_context_(device_context) {} - - const Variable* Input(int index) const; - Variable* Output(int index) const; - - public: - const OperatorBase* op_; - const std::shared_ptr scope_; - const DeviceContext* device_context_; -}; - /** * OperatorBase has the basic element that Net will call to do computation. * Only CreateOperator from OpRegistry will new Operator directly. User @@ -77,7 +55,10 @@ class OperatorBase { /// Net will call this function to Run an op. virtual void Run(const std::shared_ptr& scope, - const DeviceContext* dev_ctx) const = 0; + const platform::DeviceContext& dev_ctx) const = 0; + + protected: + std::string Type() const { return desc_.type(); } public: OpDesc desc_; @@ -86,22 +67,84 @@ class OperatorBase { AttributeMap attrs_; }; +class OpKernel { + public: + /** + * KernelContext is the only parameter of Kernel Run function. + * Run will get input/output variables, state such as momentum and + * device resource such as CUDA stream, cublas handle, etc. from + * KernelContext. User should construct it before run the Operator. + */ + class KernelContext { + public: + KernelContext(const OperatorBase* op, const std::shared_ptr& scope, + const platform::DeviceContext& device_context) + : op_(*op), scope_(scope), device_context_(device_context) {} + + const Variable* Input(int index) const { + return scope_->GetVariable(op_.inputs_[index]); + } + + Variable* Output(int index) const { + return scope_->GetVariable(op_.outputs_[index]); + } + + const OperatorBase& op_; + const std::shared_ptr& scope_; + const platform::DeviceContext& device_context_; + }; + + virtual void Compute(const KernelContext& context) const = 0; + + virtual ~OpKernel() {} +}; + class OperatorWithKernel : public OperatorBase { public: - virtual ~OperatorWithKernel() {} + struct OpKernelKey { + platform::Place place_; - virtual void InferShape(const std::shared_ptr& scope) const {} + OpKernelKey() = default; + OpKernelKey(const platform::DeviceContext& dev_ctx) { + place_ = dev_ctx.GetPlace(); + } + + bool operator==(const OpKernelKey& o) const { return place_ == o.place_; } + }; + + struct OpKernelHash { + std::hash hash_; + size_t operator()(const OpKernelKey& key) const { + return hash_(platform::is_gpu_place(key.place_)); + } + }; + + using OpKernelMap = + std::unordered_map, OpKernelHash>; void Run(const std::shared_ptr& scope, - const DeviceContext* dev_ctx) const { - OpRunContext op_ctx(this, scope, dev_ctx); - Run(&op_ctx); + const platform::DeviceContext& dev_ctx) const final { + auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx)); + opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx)); } - /// when implement an Op, your should implement this function. - /// this function should be moved to OpKernel later - virtual void Run(const OpRunContext* context) const = 0; + static std::unordered_map& + AllOpKernels() { + static std::unordered_map g_all_op_kernels; + return g_all_op_kernels; + }; }; } // namespace framework } // namespace paddle + +#define REGISTER_OP_KERNEL(type, PlaceType, KernelType) \ + struct __op_kernel_register__##type##__ { \ + __op_kernel_register__##type##__() { \ + ::paddle::framework::OperatorWithKernel::OpKernelKey key; \ + key.place_ = PlaceType(); \ + ::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \ + .reset(new KernelType()); \ + } \ + }; \ + static __op_kernel_register__##type##__ __reg_kernel_##type##__ diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 48808dabb2711936550d04cb49003e87663d3d27..be8c4be2d429648b3c8a708c7f8bdcae3ff2d283 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -19,17 +19,15 @@ limitations under the License. */ namespace paddle { namespace framework { -class OperatorTest : public OperatorWithKernel { +class OperatorTest : public OperatorBase { public: - void Run(const OpRunContext* ctx) const override { - float scale = ctx->op_->GetAttr("scale"); - PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized"); - PADDLE_ENFORCE(ctx->Output(0) == nullptr, - "Output(1) should not initialized"); - auto output1 = ctx->scope_->CreateVariable("output1"); - PADDLE_ENFORCE(output1 != nullptr, "should create output1 from scope"); - printf("get attr %s = %f\n", "scale", scale); - printf("%s\n", DebugString().c_str()); + void InferShape(const std::shared_ptr& scope) const override {} + void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const override { + float scale = GetAttr("scale"); + ASSERT_NEAR(scale, 3.14, 1e-5); + ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); + ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); } }; @@ -47,34 +45,79 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { } }; -REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator) +REGISTER_OP(test_operator, OperatorTest, OperatorTestProtoAndCheckerMaker); -TEST(OperatorBase, DebugString) { +TEST(OperatorBase, all) { OpDesc op_desc; op_desc.set_type("test_operator"); - std::vector inputs = {"IN1", "IN2"}; - for (auto& input : inputs) { - op_desc.add_inputs(input); - } - std::vector outputs = {"OUT1", "OUT2"}; - for (auto& output : outputs) { - op_desc.add_outputs(output); - } + *op_desc.mutable_inputs()->Add() = "IN1"; + *op_desc.mutable_outputs()->Add() = "OUT1"; auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); attr->set_type(paddle::framework::AttrType::FLOAT); float scale = 3.14; attr->set_f(scale); - DeviceContext device_context; + platform::CPUDeviceContext device_context; auto scope = std::make_shared(); OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - ASSERT_EQ(op->inputs_, inputs); - ASSERT_EQ(op->outputs_, outputs); ASSERT_EQ(op->GetAttr("scale"), scale); - op->Run(scope, &device_context); + scope->CreateVariable("OUT1"); + op->Run(scope, device_context); + std::cout << op->DebugString() << std::endl; + delete op; } +class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { + public: + OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("input", "input of test op"); + AddOutput("output", "output of test op"); + AddAttr("scale", "scale of cosine op") + .SetDefault(1.0) + .LargerThan(0.0); + AddType("test_operator"); + AddComment("This is test op"); + } +}; + +class OpWithKernelTest : public OperatorWithKernel { + public: + void InferShape(const std::shared_ptr& scope) const override {} +}; + +class CPUKernelTest : public OpKernel { + public: + void Compute(const KernelContext& context) const { + float scale = context.op_.GetAttr("scale"); + ASSERT_NEAR(scale, 3.14, 1e-5); + std::cout << "this is cpu kernel" << std::endl; + std::cout << context.op_.DebugString() << std::endl; + } +}; + +REGISTER_OP(op_with_kernel, OpWithKernelTest, OpKernelTestProtoAndCheckerMaker); +REGISTER_OP_KERNEL(op_with_kernel, platform::CPUPlace, CPUKernelTest); + +TEST(OpKernel, all) { + OpDesc op_desc; + op_desc.set_type("op_with_kernel"); + *op_desc.mutable_inputs()->Add() = "IN1"; + *op_desc.mutable_outputs()->Add() = "OUT1"; + auto attr = op_desc.mutable_attrs()->Add(); + attr->set_name("scale"); + attr->set_type(paddle::framework::AttrType::FLOAT); + attr->set_f(3.14); + + platform::CPUDeviceContext cpu_device_context; + auto scope = std::make_shared(); + + OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); + op->Run(scope, cpu_device_context); + + delete op; +} } // namespace framework } // namespace paddle \ No newline at end of file diff --git a/paddle/operators/.clang-format b/paddle/operators/.clang-format deleted file mode 100644 index 29282dc87e2c499988c17d90d47d44cd5cf7f115..0000000000000000000000000000000000000000 --- a/paddle/operators/.clang-format +++ /dev/null @@ -1,5 +0,0 @@ ---- -Language: Cpp -BasedOnStyle: Google -Standard: Cpp11 -... diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/paddle/operators/demo_op.h b/paddle/operators/demo_op.h deleted file mode 100644 index d0b7420b4e25d21f718d5e10d62faeb475931a18..0000000000000000000000000000000000000000 --- a/paddle/operators/demo_op.h +++ /dev/null @@ -1,59 +0,0 @@ -#pragma once - -#include "paddle/framework/op_registry.h" - -using namespace paddle::framework; - -namespace paddle { -namespace operators { - -class CosineOp : public OperatorWithKernel { - public: - void Run(const OpRunContext *context) const override { - printf("%s\n", DebugString().c_str()); - } -}; - -class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { - public: - CosineOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of cosine op"); - AddOutput("output", "output of cosine op"); - AddAttr("scale", "scale of cosine op") - .SetDefault(1.0) - .LargerThan(0.0); - AddType("cos"); - AddComment("This is cos op"); - } -}; - -REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) - -class MyTestOp : public OperatorWithKernel { - public: - void Run(const OpRunContext *context) const override { - printf("%s\n", DebugString().c_str()); - } -}; - -class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { - public: - MyTestOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of cosine op"); - AddOutput("output", "output of cosine op"); - auto my_checker = [](int i) { - PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); - }; - AddAttr("test_attr", "a simple test attribute") - .AddCustomChecker(my_checker); - AddType("my_test_op"); - AddComment("This is my_test op"); - } -}; - -REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op) - -} // namespace operators -} // namespace operators diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 7a198aec6cf12c92cb24a8e560508d06db5e1dcf..358d14f4555e1d046c8e7b91e23d54fb504926e5 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -1,11 +1,12 @@ add_subdirectory(dynload) -nv_test(cuda_test SRCS cuda_test.cu DEPS dyload_cuda) +nv_test(cuda_test SRCS cuda_test.cu) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) + IF(WITH_GPU) - set(GPU_CTX_DEPS dyload_cuda dynamic_loader ) + set(GPU_CTX_DEPS dynload_cuda dynamic_loader) ELSE() set(GPU_CTX_DEPS) ENDIF() diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index a2dea2ed1e11817c23dd2dc55a578d8fbd21ecb2..960ef0a5955bfe5f7d33b7c8e4524176b0dbfda6 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -1,13 +1,30 @@ -#include +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/platform/device_context.h" namespace paddle { namespace platform { -namespace dynload { -namespace dummy { -// Make DeviceContext A library. -int DUMMY_VAR_FOR_DEV_CTX = 0; -} // namespace dummy -} // namespace dynload +template <> +Eigen::DefaultDevice* DeviceContext::get_eigen_device() { + return reinterpret_cast(this)->eigen_device(); +} + +#ifndef PADDLE_ONLY_CPU +template <> +Eigen::GpuDevice* DeviceContext::get_eigen_device() { + return reinterpret_cast(this)->eigen_device(); +} +#endif + } // namespace platform -} // namespace paddle \ No newline at end of file +} // namespace paddle diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 160eb4e12060b36c4fefba499d4e83b9aab92848..7de07d06bed885d6529a884fb81fedbdaba78f4a 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -22,8 +19,9 @@ limitations under the License. */ #include "paddle/platform/dynload/curand.h" #define EIGEN_USE_GPU #endif -#include "paddle/platform/place.h" -#include "unsupported/Eigen/CXX11/Tensor" +#include +#include +#include namespace paddle { namespace platform { @@ -31,9 +29,29 @@ namespace platform { class DeviceContext { public: virtual ~DeviceContext() {} + virtual Place GetPlace() const = 0; + + template + DeviceType* get_eigen_device(); }; -class CPUDeviceContext : public DeviceContext {}; +class CPUDeviceContext : public DeviceContext { + public: + Eigen::DefaultDevice* eigen_device() { + if (!eigen_device_) { + eigen_device_.reset(new Eigen::DefaultDevice()); + } + return eigen_device_.get(); + } + + Place GetPlace() const override { + Place retv = CPUPlace(); + return retv; + } + + private: + std::unique_ptr eigen_device_; +}; #ifndef PADDLE_ONLY_CPU @@ -57,8 +75,13 @@ class CUDADeviceContext : public DeviceContext { GPUPlaceGuard guard(gpu_place_); paddle::platform::throw_on_error(cudaStreamCreate(&stream_), "cudaStreamCreate failed"); - eigen_stream_ = new Eigen::CudaStreamDevice(&stream_); - eigen_device_ = new Eigen::GpuDevice(eigen_stream_); + eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_)); + eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); + } + + Place GetPlace() const override { + Place retv = GPUPlace(); + return retv; } void Wait() { @@ -68,7 +91,7 @@ class CUDADeviceContext : public DeviceContext { cudaStream_t stream() { return stream_; } - Eigen::GpuDevice eigen_device() { return *eigen_device_; } + Eigen::GpuDevice* eigen_device() { return eigen_device_.get(); } cublasHandle_t cublas_handle() { if (!blas_handle_) { @@ -133,10 +156,8 @@ class CUDADeviceContext : public DeviceContext { rand_generator_) == CURAND_STATUS_SUCCESS, "curandDestroyGenerator failed"); } - - delete eigen_stream_; - delete eigen_device_; - + eigen_stream_.reset(); + eigen_device_.reset(); paddle::platform::throw_on_error(cudaStreamDestroy(stream_), "cudaStreamDestroy failed"); } @@ -145,8 +166,8 @@ class CUDADeviceContext : public DeviceContext { GPUPlace gpu_place_; cudaStream_t stream_; - Eigen::CudaStreamDevice* eigen_stream_; - Eigen::GpuDevice* eigen_device_; + std::unique_ptr eigen_stream_; + std::unique_ptr eigen_device_; cublasHandle_t blas_handle_{nullptr}; @@ -155,6 +176,8 @@ class CUDADeviceContext : public DeviceContext { int random_seed_; curandGenerator_t rand_generator_{nullptr}; }; + #endif + } // namespace platform } // namespace paddle diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 61be4a307dbf073be7dff4564183240834cc7df6..af2ce17fc2238dda62e9888ebe9426edcd55d2bc 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -15,13 +15,26 @@ limitations under the License. */ #include "paddle/platform/device_context.h" #include "gtest/gtest.h" -TEST(CUDADeviceContext, Init) { +using DEVICE_GPU = Eigen::GpuDevice; +TEST(Device, Init) { + int count = paddle::platform::GetDeviceCount(); + for (int i = 0; i < count; i++) { + paddle::platform::DeviceContext* device_context = + new paddle::platform::CUDADeviceContext(i); + Eigen::GpuDevice* gpu_device = + device_context->template get_eigen_device(); + ASSERT_NE(nullptr, gpu_device); + delete device_context; + } +} + +TEST(Device, CUDADeviceContext) { int count = paddle::platform::GetDeviceCount(); for (int i = 0; i < count; i++) { paddle::platform::CUDADeviceContext* device_context = new paddle::platform::CUDADeviceContext(i); - Eigen::GpuDevice gpu_device = device_context->eigen_device(); - ASSERT_NE(nullptr, gpu_device.stream()); + Eigen::GpuDevice* gpu_device = device_context->eigen_device(); + ASSERT_NE(nullptr, gpu_device); cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); ASSERT_NE(nullptr, cudnn_handle); cublasHandle_t cublas_handle = device_context->cublas_handle(); diff --git a/paddle/platform/dynload/CMakeLists.txt b/paddle/platform/dynload/CMakeLists.txt index 4a8866b3d364542f315978859e96290c6f067f6f..d205ead84598e04eea523be32139959a02e0dd83 100644 --- a/paddle/platform/dynload/CMakeLists.txt +++ b/paddle/platform/dynload/CMakeLists.txt @@ -1,2 +1,2 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags) -nv_library(dyload_cuda SRCS cublas.cc cudnn.cc curand.cc) +nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 55aebc59eca50ad33e8a5357c5ca29d4101f754b..f9f87acf15a6b62c343cc0e3db9ebc7e0aabb786 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include @@ -43,4 +44,4 @@ All parameter, weight, gradient are variables in Paddle. py::return_value_policy::reference); return m.ptr(); -} \ No newline at end of file +}