diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt
index 2c1eb7521d896e6d67021e4d020f274bb32d123c..58a35564f83928ee0bdaad63b154ed57d8d8a735 100644
--- a/paddle/CMakeLists.txt
+++ b/paddle/CMakeLists.txt
@@ -15,7 +15,6 @@ if(Boost_FOUND)
   add_subdirectory(memory)
   add_subdirectory(platform)
   add_subdirectory(framework)
-  add_subdirectory(operators)
   add_subdirectory(pybind)
 endif()
 
diff --git a/paddle/framework/dim.h b/paddle/framework/dim.h
index bcde291d12d429a3f2cd41fa6d0ee606c7c9c92f..883fdc55eb929ebc51e8ae05938e9d07374406ce 100644
--- a/paddle/framework/dim.h
+++ b/paddle/framework/dim.h
@@ -266,29 +266,6 @@ HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) {
   return ((0 <= idx.head) && (idx.head < size.head));
 }
 
-/**
- * \brief Check if a size and a stride create a Fortran order contiguous
- * block of memory.
- */
-template <int i>
-HOST bool contiguous(const Dim<i>& size, const Dim<i>& stride, int mul = 1) {
-  if (product(size) == 0) return true;
-  int contiguous_stride = get<0>(size) == 1 ? 0 : mul;
-  return (get<0>(stride) == contiguous_stride &&
-          contiguous(size.tail, stride.tail, mul * get<0>(size)));
-}
-
-///\cond HIDDEN
-// Base case of contiguous, check the nth stride is the size of
-// the prefix multiply of n-1 dims.
-template <>
-inline bool contiguous(const Dim<1>& size, const Dim<1>& stride, int mul) {
-  if (get<0>(size) == 0) return true;
-  int contiguous_stride = get<0>(size) == 1 ? 0 : mul;
-  return get<0>(stride) == contiguous_stride;
-}
-///\endcond
-
 /**
  * \brief Compute exclusive prefix-multiply of a Dim.
  */
@@ -306,31 +283,6 @@ HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) {
 }
 ///\endcond
 
-/**
- * \brief Calculate strides of a contiguous array of the given size
- *
- * Sets the stride for any dimension with an extent of 1 to 0.
- * \param size Dim object containing the size of the array.
- * \param base The base stride to use.
- * \return Dim object the same size as \p size with the strides.
- */
-template <int i>
-HOSTDEVICE Dim<i> contiguous_strides(const Dim<i>& size, int base = 1) {
-  int stride = size.head == 1 ? 0 : base;
-  return Dim<i>(stride, contiguous_strides(size.tail, base * size.head));
-}
-
-///\cond HIDDEN
-
-// Base case of contiguous_strides
-template <>
-HOSTDEVICE inline Dim<1> contiguous_strides(const Dim<1>& size, int base) {
-  int stride = size.head == 1 ? 0 : base;
-  return Dim<1>(stride);
-}
-
-///\endcond
-
 /**
  * Add two dimensions together
  */
diff --git a/paddle/framework/dim_test.cu b/paddle/framework/dim_test.cu
index 809bf04826637195425a32c054c94e00ef940df9..05217415196f3ec3ce9b5de7cb2f82c9de960ba7 100644
--- a/paddle/framework/dim_test.cu
+++ b/paddle/framework/dim_test.cu
@@ -58,24 +58,6 @@ TEST(Dim, Equality) {
   EXPECT_EQ(paddle::framework::get<1>(c), 3);
   EXPECT_EQ(paddle::framework::get<2>(c), 12);
 
-  // contiguous_strides
-  c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 1, 10));
-  EXPECT_EQ(paddle::framework::get<0>(c), 1);
-  EXPECT_EQ(paddle::framework::get<1>(c), 0);
-  EXPECT_EQ(paddle::framework::get<2>(c), 10);
-  c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 10, 1));
-  EXPECT_EQ(paddle::framework::get<0>(c), 1);
-  EXPECT_EQ(paddle::framework::get<1>(c), 10);
-  EXPECT_EQ(paddle::framework::get<2>(c), 0);
-  c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(1, 10, 10));
-  EXPECT_EQ(paddle::framework::get<0>(c), 0);
-  EXPECT_EQ(paddle::framework::get<1>(c), 1);
-  EXPECT_EQ(paddle::framework::get<2>(c), 10);
-  c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(2, 3, 4));
-  EXPECT_EQ(paddle::framework::get<0>(c), 1);
-  EXPECT_EQ(paddle::framework::get<1>(c), 2);
-  EXPECT_EQ(paddle::framework::get<2>(c), 6);
-
   // generate from an index
   auto size = paddle::framework::make_dim(4, 5, 2);
   c = paddle::framework::Dim<3>(14, size);
@@ -101,16 +83,6 @@ TEST(Dim, Bool) {
   EXPECT_TRUE(a == a);
   EXPECT_FALSE(a == b);
   EXPECT_TRUE(a == c);
-
-  // contiguous check
-  int x = 4, y = 5, z = 2;
-  paddle::framework::Dim<3> sizef(x, y, z);
-  paddle::framework::Dim<3> stridea(1, x, x*y);
-  paddle::framework::Dim<3> strideb(2, 2*x, 2*x*y);
-  paddle::framework::Dim<3> stridec(1, x, 2*x*y);
-  EXPECT_TRUE(paddle::framework::contiguous(sizef, stridea));
-  EXPECT_FALSE(paddle::framework::contiguous(sizef, strideb));
-  EXPECT_FALSE(paddle::framework::contiguous(sizef, stridec));
 }
 
 TEST(Dim, Print) {
diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc
index c4baafc2aebc8d009a388635bbab180d86a4b914..f5d45a80bb8e9fa095e7d6adc6370918b3f87f5a 100644
--- a/paddle/framework/op_registry_test.cc
+++ b/paddle/framework/op_registry_test.cc
@@ -1,17 +1,15 @@
 #include "paddle/framework/op_registry.h"
 #include <gtest/gtest.h>
-#include "paddle/framework/operator.h"
-#include "paddle/operators/demo_op.h"
 
 using namespace paddle::framework;
 
 namespace paddle {
 namespace framework {
-class CosineOp : public OperatorWithKernel {
+class CosineOp : public OperatorBase {
  public:
-  void Run(const OpRunContext* context) const override {
-    printf("%s\n", DebugString().c_str());
-  }
+  void Run(const std::shared_ptr<Scope>& scope,
+           const platform::DeviceContext& dev_ctx) const override {}
+  void InferShape(const std::shared_ptr<Scope>& scope) const override {}
 };
 
 class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@@ -30,12 +28,13 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 
 REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
 
-class MyTestOp : public OperatorWithKernel {
+class MyTestOp : public OperatorBase {
+ public:
+  void InferShape(const std::shared_ptr<Scope>& scope) const override {}
+  void Run(const std::shared_ptr<Scope>& scope,
+           const platform::DeviceContext& dev_ctx) const override {}
+
  public:
-  void Run(const OpRunContext* ctx) const override {
-    printf("%s\n", DebugString().c_str());
-    printf("test_attr = %d\n", ctx->op_->GetAttr<int>("test_attr"));
-  }
 };
 
 class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@@ -73,8 +72,8 @@ TEST(OpRegistry, CreateOp) {
   paddle::framework::OperatorBase* op =
       paddle::framework::OpRegistry::CreateOp(op_desc);
   auto scope = std::make_shared<Scope>();
-  auto dev_ctx = DeviceContext();
-  op->Run(scope, &dev_ctx);
+  paddle::platform::CPUDeviceContext dev_ctx;
+  op->Run(scope, dev_ctx);
   float scale_get = op->GetAttr<float>("scale");
   ASSERT_EQ(scale_get, scale);
 }
@@ -116,8 +115,8 @@ TEST(OpRegistry, DefaultValue) {
   paddle::framework::OperatorBase* op =
       paddle::framework::OpRegistry::CreateOp(op_desc);
   auto scope = std::make_shared<Scope>();
-  auto dev_ctx = DeviceContext();
-  op->Run(scope, &dev_ctx);
+  paddle::platform::CPUDeviceContext dev_ctx;
+  op->Run(scope, dev_ctx);
   ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
 }
 
@@ -169,9 +168,9 @@ TEST(OpRegistry, CustomChecker) {
   attr->set_i(4);
   paddle::framework::OperatorBase* op =
       paddle::framework::OpRegistry::CreateOp(op_desc);
-  auto dev_ctx = DeviceContext();
+  paddle::platform::CPUDeviceContext dev_ctx;
   auto scope = std::make_shared<Scope>();
-  op->Run(scope, &dev_ctx);
+  op->Run(scope, dev_ctx);
   int test_attr = op->GetAttr<int>("test_attr");
   ASSERT_EQ(test_attr, 4);
 }
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 3db3706e47dfab49f54b3f1f9f2e41c53fc3f298..8f7adff8b3982e91a3d7f6d598cd62d5005d5f17 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -39,13 +39,5 @@ std::string OperatorBase::DebugString() const {
   return ss.str();
 }
 
-const Variable* OpRunContext::Input(int index) const {
-  return scope_->GetVariable(op_->inputs_[index]);
-}
-
-Variable* OpRunContext::Output(int index) const {
-  return scope_->GetVariable(op_->outputs_[index]);
-}
-
 }  // namespace framework
 }  // namespace paddle
\ No newline at end of file
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 6570d5869814a198195968606d041055e847ca08..0ce422e007c39ce0c3f5f7a89650cc211919ea8f 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -14,44 +14,22 @@ limitations under the License. */
 
 #pragma once
 
+#include <paddle/framework/attr_checker.h>
+#include <paddle/framework/op_desc.pb.h>
+#include <paddle/framework/scope.h>
+#include <paddle/platform/device_context.h>
+#include <paddle/platform/place.h>
+#include <paddle/utils/Error.h>
 #include <memory>
 #include <string>
 #include <unordered_map>
 #include <vector>
-#include "paddle/framework/attr_checker.h"
-#include "paddle/framework/op_desc.pb.h"
-#include "paddle/framework/scope.h"
-#include "paddle/utils/Error.h"
-
 namespace paddle {
 namespace framework {
 
 class OperatorBase;
 
-class DeviceContext {};
-
-/**
- * OpRunContext is the only parameter of Operator's Run function.
- * Run will get input/output variables, state such as momentum and
- * device resource such as CUDA stream, cublas handle, etc. from
- * OpRunContext. User should construct it before run the Operator.
- */
-class OpRunContext {
- public:
-  OpRunContext(const OperatorBase* op, const std::shared_ptr<Scope> scope,
-               const DeviceContext* device_context)
-      : op_(op), scope_(scope), device_context_(device_context) {}
-
-  const Variable* Input(int index) const;
-  Variable* Output(int index) const;
-
- public:
-  const OperatorBase* op_;
-  const std::shared_ptr<Scope> scope_;
-  const DeviceContext* device_context_;
-};
-
 /**
  * OperatorBase has the basic element that Net will call to do computation.
  * Only CreateOperator from OpRegistry will new Operator directly. User
@@ -77,7 +55,10 @@ class OperatorBase {
 
   /// Net will call this function to Run an op.
   virtual void Run(const std::shared_ptr<Scope>& scope,
-                   const DeviceContext* dev_ctx) const = 0;
+                   const platform::DeviceContext& dev_ctx) const = 0;
+
+ protected:
+  std::string Type() const { return desc_.type(); }
 
  public:
   OpDesc desc_;
@@ -86,22 +67,84 @@ class OperatorBase {
   AttributeMap attrs_;
 };
 
+class OpKernel {
+ public:
+  /**
+   * KernelContext is the only parameter of Kernel Run function.
+   * Run will get input/output variables, state such as momentum and
+   * device resource such as CUDA stream, cublas handle, etc. from
+   * KernelContext. User should construct it before run the Operator.
+   */
+  class KernelContext {
+   public:
+    KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
+                  const platform::DeviceContext& device_context)
+        : op_(*op), scope_(scope), device_context_(device_context) {}
+
+    const Variable* Input(int index) const {
+      return scope_->GetVariable(op_.inputs_[index]);
+    }
+
+    Variable* Output(int index) const {
+      return scope_->GetVariable(op_.outputs_[index]);
+    }
+
+    const OperatorBase& op_;
+    const std::shared_ptr<Scope>& scope_;
+    const platform::DeviceContext& device_context_;
+  };
+
+  virtual void Compute(const KernelContext& context) const = 0;
+
+  virtual ~OpKernel() {}
+};
+
 class OperatorWithKernel : public OperatorBase {
  public:
-  virtual ~OperatorWithKernel() {}
+  struct OpKernelKey {
+    platform::Place place_;
 
-  virtual void InferShape(const std::shared_ptr<Scope>& scope) const {}
+    OpKernelKey() = default;
+    OpKernelKey(const platform::DeviceContext& dev_ctx) {
+      place_ = dev_ctx.GetPlace();
+    }
+
+    bool operator==(const OpKernelKey& o) const { return place_ == o.place_; }
+  };
+
+  struct OpKernelHash {
+    std::hash<bool> hash_;
+    size_t operator()(const OpKernelKey& key) const {
+      return hash_(platform::is_gpu_place(key.place_));
+    }
+  };
+
+  using OpKernelMap =
+      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
 
   void Run(const std::shared_ptr<Scope>& scope,
-           const DeviceContext* dev_ctx) const {
-    OpRunContext op_ctx(this, scope, dev_ctx);
-    Run(&op_ctx);
+           const platform::DeviceContext& dev_ctx) const final {
+    auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
+    opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
   }
 
-  /// when implement an Op, your should implement this function.
-  /// this function should be moved to OpKernel later
-  virtual void Run(const OpRunContext* context) const = 0;
+  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
+  AllOpKernels() {
+    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
+    return g_all_op_kernels;
+  };
 };
 
 }  // namespace framework
 }  // namespace paddle
+
+#define REGISTER_OP_KERNEL(type, PlaceType, KernelType)                     \
+  struct __op_kernel_register__##type##__ {                                 \
+    __op_kernel_register__##type##__() {                                    \
+      ::paddle::framework::OperatorWithKernel::OpKernelKey key;             \
+      key.place_ = PlaceType();                                             \
+      ::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key]   \
+          .reset(new KernelType());                                         \
+    }                                                                       \
+  };                                                                        \
+  static __op_kernel_register__##type##__ __reg_kernel_##type##__
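Note: to make the new interface above concrete, here is a minimal sketch of how a kernel would be written and registered against it. CosKernel is a hypothetical name invented for this sketch, not code from the patch; only OpKernel, KernelContext, and REGISTER_OP_KERNEL come from the header above.

// Sketch only, not part of this patch. A kernel overrides Compute() and is
// filed under its op type and place, where OperatorWithKernel::Run finds it.
#include "paddle/framework/operator.h"

namespace paddle {
namespace operators {

class CosKernel : public framework::OpKernel {
 public:
  void Compute(const KernelContext& context) const override {
    // A real kernel would read context.Input(0), write context.Output(0),
    // and use the stream/handles carried by context.device_context_.
  }
};

}  // namespace operators
}  // namespace paddle

// Puts CosKernel into AllOpKernels()["cos_sim"] under the CPU place key,
// the exact entry OperatorWithKernel::Run looks up when it is handed a
// CPUDeviceContext.
REGISTER_OP_KERNEL(cos_sim, ::paddle::platform::CPUPlace,
                   ::paddle::operators::CosKernel);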
diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc
index 48808dabb2711936550d04cb49003e87663d3d27..86f45f108a5d2189894fd59483a84b039a010ab3 100644
--- a/paddle/framework/operator_test.cc
+++ b/paddle/framework/operator_test.cc
@@ -19,17 +19,15 @@ limitations under the License. */
 
 namespace paddle {
 namespace framework {
-class OperatorTest : public OperatorWithKernel {
+class OperatorTest : public OperatorBase {
  public:
-  void Run(const OpRunContext* ctx) const override {
-    float scale = ctx->op_->GetAttr<float>("scale");
-    PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized");
-    PADDLE_ENFORCE(ctx->Output(0) == nullptr,
-                   "Output(1) should not initialized");
-    auto output1 = ctx->scope_->CreateVariable("output1");
-    PADDLE_ENFORCE(output1 != nullptr, "should create output1 from scope");
-    printf("get attr %s = %f\n", "scale", scale);
-    printf("%s\n", DebugString().c_str());
+  void InferShape(const std::shared_ptr<Scope>& scope) const override {}
+  void Run(const std::shared_ptr<Scope>& scope,
+           const platform::DeviceContext& dev_ctx) const override {
+    float scale = GetAttr<float>("scale");
+    ASSERT_NEAR(scale, 3.14, 1e-5);
+    ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
+    ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
   }
 };
 
@@ -49,31 +47,26 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 
 REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator)
 
-TEST(OperatorBase, DebugString) {
+TEST(OperatorBase, all) {
   OpDesc op_desc;
   op_desc.set_type("test_operator");
-  std::vector<std::string> inputs = {"IN1", "IN2"};
-  for (auto& input : inputs) {
-    op_desc.add_inputs(input);
-  }
-  std::vector<std::string> outputs = {"OUT1", "OUT2"};
-  for (auto& output : outputs) {
-    op_desc.add_outputs(output);
-  }
+  *op_desc.mutable_inputs()->Add() = "IN1";
+  *op_desc.mutable_outputs()->Add() = "OUT1";
   auto attr = op_desc.mutable_attrs()->Add();
   attr->set_name("scale");
   attr->set_type(paddle::framework::AttrType::FLOAT);
   float scale = 3.14;
   attr->set_f(scale);
 
-  DeviceContext device_context;
+  platform::CPUDeviceContext device_context;
   auto scope = std::make_shared<Scope>();
 
   OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
-  ASSERT_EQ(op->inputs_, inputs);
-  ASSERT_EQ(op->outputs_, outputs);
   ASSERT_EQ(op->GetAttr<float>("scale"), scale);
-  op->Run(scope, &device_context);
+  scope->CreateVariable("OUT1");
+  op->Run(scope, device_context);
+  std::cout << op->DebugString() << std::endl;
+  delete op;
 }
 
 }  // namespace framework
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index ce5d98b04e6b53fcedc4fc4610d9390e64846b2a..a0945e8055625ca4c21ea1c3fa9f27321ca9ba3c 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include <cstdint>
 #include <memory>
 #include <type_traits>
 #include "paddle/framework/ddim.h"
@@ -26,31 +27,65 @@ namespace framework {
 
 class Tensor {
  public:
+  Tensor() : offset_(0) {}
+
+  explicit Tensor(const DDim& dims) : dims_(dims), offset_(0) {}
+
   template <typename T>
   const T* data() const {
-    PADDLE_ENFORCE(holder_ != nullptr,
-                   "Tensor::data must be called after Tensor::mutable_data.");
-    return static_cast<const T*>(holder_->Ptr());
+    PADDLE_ENFORCE(
+        holder_ != nullptr,
+        "Tensor has not been initialized. Call Tensor::mutable_data first.");
+    return reinterpret_cast<const T*>(
+        reinterpret_cast<uintptr_t>(holder_->Ptr()) + offset_);
   }
 
   template <typename T,  // must be POD types
             typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
   T* mutable_data(DDim dims, paddle::platform::Place place) {
+    dims_ = dims;
     if (holder_ == nullptr ||
         !(holder_->Place() ==
           place) /* some versions of boost::variant don't have operator!= */
-        || holder_->Size() < product(dims) * sizeof(T)) {
+        || holder_->Size() < product(dims) * sizeof(T) + offset_) {
       holder_.reset(new PlaceholderImpl<T>(place, product(dims) * sizeof(T)));
+      offset_ = 0;
     }
-    return static_cast<T*>(holder_->Ptr());
+    return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->Ptr()) +
+                                offset_);
   }
 
-  template <typename T,  // must be POD types
-            typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
-  T* mutable_data(DDim dims) {
-    return mutable_data<T>(dims, paddle::platform::get_place());
+  void ShareDataFrom(const Tensor& src) {
+    PADDLE_ENFORCE(src.holder_ != nullptr,
+                   "Can not share data from an uninitialized tensor.");
+    holder_ = src.holder_;
+    dims_ = src.dims_;
+    offset_ = src.offset_;
   }
 
+  Tensor Slice(const int& begin_idx, const int& end_idx) const {
+    PADDLE_ENFORCE(holder_ != nullptr,
+                   "The sliced tensor has not been initialized.");
+    PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0],
+                   "Slice index is less than zero or out of bound.");
+    PADDLE_ENFORCE(begin_idx < end_idx,
+                   "Begin index must be less than end index.");
+    PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1.");
+    std::vector<int> d = vectorize(dims_);
+    int base = 1;
+    for (size_t i = 1; i < d.size(); ++i) {
+      base *= d[i];
+    }
+    Tensor dst;
+    dst.holder_ = holder_;
+    dst.dims_ = dims_;
+    dst.dims_[0] = end_idx - begin_idx;
+    dst.offset_ = offset_ + begin_idx * base * holder_->TypeSize();
+    return dst;
+  }
+
+  DDim dims() const { return dims_; }
+
  private:
   // Placeholder hides type T, so it doesn't appear as a template
   // parameter of Variable.
@@ -59,6 +94,7 @@ class Tensor {
     virtual void* Ptr() const = 0;
     virtual paddle::platform::Place Place() const = 0;
     virtual size_t Size() const = 0;
+    virtual size_t TypeSize() const = 0;
   };
 
   template <typename T>
@@ -85,6 +121,7 @@ class Tensor {
     virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); }
     virtual size_t Size() const { return size_; }
     virtual paddle::platform::Place Place() const { return place_; }
+    virtual size_t TypeSize() const { return sizeof(T); }
 
     std::unique_ptr<T, Deleter> ptr_;
     paddle::platform::Place place_;  // record the place of ptr_.
@@ -92,6 +129,8 @@ class Tensor {
   };
 
   std::shared_ptr<Placeholder> holder_;  // holds the memory block if allocated.
+  DDim dims_;
+  size_t offset_;  // marks the begin of tensor data area.
 };
 
 }  // namespace framework
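Note: the Slice arithmetic above in one concrete case. For dims_ = {5, 3, 4}, each step along dimension 0 covers base = 3 * 4 = 12 elements, so Slice(1, 3) keeps the same holder_ and only sets dims_[0] = 2 and offset_ = 1 * 12 * TypeSize() bytes. A minimal usage sketch, assuming CPU allocation is available (the tests below remain commented out for exactly that reason):

// Sketch only, not part of this patch.
#include "paddle/framework/tensor.h"

void slice_example() {
  using paddle::framework::Tensor;
  using paddle::framework::make_ddim;
  Tensor src;
  src.mutable_data<int>(make_ddim({5, 3, 4}), paddle::platform::CPUPlace());
  Tensor s = src.Slice(1, 3);
  // s.dims() == {2, 3, 4}; s's data<int>() equals src's data<int>() plus
  // 12 * sizeof(int) bytes, and both share one reference-counted holder_.
}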
diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index 727d81f8d72e39ec564c42a48bf7ff64204adfff..f4822838cfbd27656232a23b14f716f2fbe510e0 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -15,15 +15,27 @@
 #include <gtest/gtest.h>
 #include <string>
 
-TEST(Tensor, ASSERT) {
-  paddle::framework::Tensor cpu_tensor;
+TEST(Tensor, Dims) {
+  using namespace paddle::framework;
+  using namespace paddle::platform;
+  Tensor tt(make_ddim({2, 3, 4}));
+  DDim dims = tt.dims();
+  ASSERT_EQ(arity(dims), 3);
+  for (int i = 0; i < 3; ++i) {
+    EXPECT_EQ(i + 2, dims[i]);
+  }
+}
+
+TEST(Tensor, DataAssert) {
+  paddle::framework::Tensor src_tensor;
 
   bool caught = false;
   try {
-    const double* p __attribute__((unused)) = cpu_tensor.data<double>();
+    src_tensor.data<double>();
   } catch (paddle::framework::EnforceNotMet err) {
     caught = true;
-    std::string msg = "Tensor::data must be called after Tensor::mutable_data.";
+    std::string msg =
+        "Tensor has not been initialized. Call Tensor::mutable_data first.";
     const char* what = err.what();
     for (size_t i = 0; i < msg.length(); ++i) {
       ASSERT_EQ(what[i], msg[i]);
@@ -32,54 +44,138 @@
   ASSERT_TRUE(caught);
 }
 
-/* mutable_data() is not tested at present
+/* following tests are not available at present
 because Memory::Alloc() and Memory::Free() have not been ready.
 
 TEST(Tensor, MutableData) {
   using namespace paddle::framework;
   using namespace paddle::platform;
   {
-    Tensor cpu_tensor;
+    Tensor src_tensor;
     float* p1 = nullptr;
     float* p2 = nullptr;
     // initialization
-    p1 = cpu_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
+    p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
     EXPECT_NE(p1, nullptr);
-    // set cpu_tensor a new dim with large size
+    // set src_tensor a new dim with large size
     // memory is supposed to be re-allocated
-    p2 = cpu_tensor.mutable_data<float>(make_ddim({3, 4}));
+    p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CPUPlace());
     EXPECT_NE(p2, nullptr);
     EXPECT_NE(p1, p2);
-    // set cpu_tensor a new dim with same size
+    // set src_tensor a new dim with same size
     // memory block is supposed to be unchanged
-    p1 = cpu_tensor.mutable_data<float>(make_ddim({2, 2, 3}));
+    p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), CPUPlace());
     EXPECT_EQ(p1, p2);
-    // set cpu_tensor a new dim with smaller size
+    // set src_tensor a new dim with smaller size
     // memory block is supposed to be unchanged
-    p2 = cpu_tensor.mutable_data<float>(make_ddim({2, 2}));
+    p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace());
     EXPECT_EQ(p1, p2);
   }
 
   {
-    Tensor gpu_tensor;
+    Tensor src_tensor;
     float* p1 = nullptr;
     float* p2 = nullptr;
     // initialization
-    p1 = gpu_tensor.mutable_data<float>(make_ddim({1, 2, 3}), GPUPlace());
+    p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), GPUPlace());
    EXPECT_NE(p1, nullptr);
-    // set gpu_tensor a new dim with large size
+    // set src_tensor a new dim with large size
     // memory is supposed to be re-allocated
-    p2 = gpu_tensor.mutable_data<float>(make_ddim({3, 4}));
+    p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), GPUPlace());
     EXPECT_NE(p2, nullptr);
     EXPECT_NE(p1, p2);
-    // set gpu_tensor a new dim with same size
+    // set src_tensor a new dim with same size
     // memory block is supposed to be unchanged
-    p1 = gpu_tensor.mutable_data<float>(make_ddim({2, 2, 3}));
+    p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), GPUPlace());
     EXPECT_EQ(p1, p2);
-    // set gpu_tensor a new dim with smaller size
+    // set src_tensor a new dim with smaller size
     // memory block is supposed to be unchanged
-    p2 = gpu_tensor.mutable_data<float>(make_ddim({2, 2}));
+    p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), GPUPlace());
     EXPECT_EQ(p1, p2);
   }
 }
-*/
+
+TEST(Tensor, ShareDataFrom) {
+  using namespace paddle::framework;
+  using namespace paddle::platform;
+  {
+    Tensor src_tensor;
+    Tensor dst_tensor;
+    // Try to share data from uninitialized tensor
+    bool caught = false;
+    try {
+      dst_tensor.ShareDataFrom(src_tensor);
+    } catch (EnforceNotMet err) {
+      caught = true;
+      std::string msg = "Can not share data from an uninitialized tensor.";
+      const char* what = err.what();
+      for (size_t i = 0; i < msg.length(); ++i) {
+        ASSERT_EQ(what[i], msg[i]);
+      }
+    }
+    ASSERT_TRUE(caught);
+
+    src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CPUPlace());
+    dst_tensor.ShareDataFrom(src_tensor);
+    ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
+  }
+
+  {
+    Tensor src_tensor;
+    Tensor dst_tensor;
+    src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), GPUPlace());
+    dst_tensor.ShareDataFrom(src_tensor);
+    ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
+  }
+}
+
+TEST(Tensor, Slice) {
+  using namespace paddle::framework;
+  using namespace paddle::platform;
+  {
+    Tensor src_tensor;
+    src_tensor.mutable_data<int>(make_ddim({5, 3, 4}), CPUPlace());
+    Tensor slice_tensor = src_tensor.Slice(1, 3);
+    DDim slice_dims = slice_tensor.dims();
+    ASSERT_EQ(arity(slice_dims), 3);
+    EXPECT_EQ(slice_dims[0], 2);
+    EXPECT_EQ(slice_dims[1], 3);
+    EXPECT_EQ(slice_dims[2], 4);
+
+    uintptr_t src_data_address =
+        reinterpret_cast<uintptr_t>(src_tensor.data<int>());
+    uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
+        src_tensor.mutable_data<int>(src_tensor.dims(), CPUPlace()));
+    uintptr_t slice_data_address =
+        reinterpret_cast<uintptr_t>(slice_tensor.data<int>());
+    uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
+        slice_tensor.mutable_data<int>(slice_tensor.dims(), CPUPlace()));
+    EXPECT_EQ(src_data_address, src_mutable_data_address);
+    EXPECT_EQ(slice_data_address, slice_mutable_data_address);
+    EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
+  }
+
+  {
+    Tensor src_tensor;
+    src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace());
+    Tensor slice_tensor = src_tensor.Slice(2, 6);
+    DDim slice_dims = slice_tensor.dims();
+    ASSERT_EQ(arity(slice_dims), 2);
+    EXPECT_EQ(slice_dims[0], 4);
+    EXPECT_EQ(slice_dims[1], 9);
+
+    uintptr_t src_data_address =
+        reinterpret_cast<uintptr_t>(src_tensor.data<double>());
+    uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
+        src_tensor.mutable_data<double>(src_tensor.dims(), GPUPlace()));
+    uintptr_t slice_data_address =
+        reinterpret_cast<uintptr_t>(slice_tensor.data<double>());
+    uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
+        slice_tensor.mutable_data<double>(slice_tensor.dims(), GPUPlace()));
+    EXPECT_EQ(src_data_address, src_mutable_data_address);
+    EXPECT_EQ(slice_data_address, slice_mutable_data_address);
+    EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
+  }
+}
+
+*/
\ No newline at end of file
diff --git a/paddle/operators/.clang-format b/paddle/operators/.clang-format
deleted file mode 100644
index 29282dc87e2c499988c17d90d47d44cd5cf7f115..0000000000000000000000000000000000000000
--- a/paddle/operators/.clang-format
+++ /dev/null
@@ -1,5 +0,0 @@
----
-Language: Cpp
-BasedOnStyle: Google
-Standard: Cpp11
-...
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/paddle/operators/demo_op.h b/paddle/operators/demo_op.h
deleted file mode 100644
index d0b7420b4e25d21f718d5e10d62faeb475931a18..0000000000000000000000000000000000000000
--- a/paddle/operators/demo_op.h
+++ /dev/null
@@ -1,59 +0,0 @@
-#pragma once
-
-#include "paddle/framework/op_registry.h"
-
-using namespace paddle::framework;
-
-namespace paddle {
-namespace operators {
-
-class CosineOp : public OperatorWithKernel {
- public:
-  void Run(const OpRunContext *context) const override {
-    printf("%s\n", DebugString().c_str());
-  }
-};
-
-class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
- public:
-  CosineOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("input", "input of cosine op");
-    AddOutput("output", "output of cosine op");
-    AddAttr<float>("scale", "scale of cosine op")
-        .SetDefault(1.0)
-        .LargerThan(0.0);
-    AddType("cos");
-    AddComment("This is cos op");
-  }
-};
-
-REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
-
-class MyTestOp : public OperatorWithKernel {
- public:
-  void Run(const OpRunContext *context) const override {
-    printf("%s\n", DebugString().c_str());
-  }
-};
-
-class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
- public:
-  MyTestOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("input", "input of cosine op");
-    AddOutput("output", "output of cosine op");
-    auto my_checker = [](int i) {
-      PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
-    };
-    AddAttr<int>("test_attr", "a simple test attribute")
-        .AddCustomChecker(my_checker);
-    AddType("my_test_op");
-    AddComment("This is my_test op");
-  }
-};
-
-REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)
-
-}  // namespace operators
-}  // namespace operators
diff --git a/paddle/platform/cuda_device_context.h b/paddle/platform/cuda_device_context.h
index 420159fb2c610f29f60e3b0a11b61e47c13055dc..c38dcd5a6158df92da2f7e19599ab3247df604f3 100644
--- a/paddle/platform/cuda_device_context.h
+++ b/paddle/platform/cuda_device_context.h
@@ -23,15 +23,13 @@ limitations under the License. */
 #include "paddle/platform/place.h"
 #include "unsupported/Eigen/CXX11/Tensor"
 
-using DEVICE_GPU = Eigen::GpuDevice;
-
 namespace paddle {
 namespace platform {
 
 class CUDADeviceContext;
 
 template <>
-DEVICE_GPU DeviceContext::get_eigen_device<DEVICE_GPU>() {
+Eigen::GpuDevice DeviceContext::get_eigen_device<Eigen::GpuDevice>() {
   return static_cast<CUDADeviceContext*>(this)->eigen_handle();
 }
 
@@ -59,6 +57,11 @@ class CUDADeviceContext : public DeviceContext {
     eigen_device_ = new Eigen::GpuDevice(eigen_stream_);
   }
 
+  Place GetPlace() const override {
+    Place retv = GPUPlace();
+    return retv;
+  }
+
   void Wait() {
     paddle::platform::throw_on_error(cudaStreamSynchronize(stream_),
                                      "cudaStreamSynchronize failed");
diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h
index 11a05702cd9876c98e730ccaaede2fb04696254e..d2a516999170efd5f5679960670db5b534de9d2f 100644
--- a/paddle/platform/device_context.h
+++ b/paddle/platform/device_context.h
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include "paddle/framework/enforce.h"
+#include "paddle/platform/place.h"
 #include "unsupported/Eigen/CXX11/Tensor"
 
-using DEVICE_CPU = Eigen::DefaultDevice;
-
 namespace paddle {
 namespace platform {
 
@@ -28,10 +28,12 @@ class DeviceContext {
 
   template <typename DeviceType>
   DeviceType get_eigen_device();
+
+  virtual Place GetPlace() const = 0;
 };
 
 template <>
-DEVICE_CPU DeviceContext::get_eigen_device<DEVICE_CPU>() {
+Eigen::DefaultDevice DeviceContext::get_eigen_device<Eigen::DefaultDevice>() {
   return static_cast<CPUDeviceContext*>(this)->eigen_handle();
 }
 
@@ -44,9 +46,13 @@ class CPUDeviceContext : public DeviceContext {
     return *eigen_handle_;
   }
 
+  Place GetPlace() const override {
+    Place retv = CPUPlace();
+    return retv;
+  }
+
  private:
   Eigen::DefaultDevice* eigen_handle_{nullptr};
 };
-
 }  // namespace platform
 }  // namespace paddle
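Note: the new GetPlace() contract ties the two headers above to the kernel lookup in operator.h: a context reports the place it computes on, and OpKernelKey/OpKernelHash discriminate kernels by exactly that place. A minimal sketch using only names from these headers:

// Sketch only, not part of this patch.
#include <cassert>
#include "paddle/platform/device_context.h"

void place_example() {
  paddle::platform::CPUDeviceContext cpu_ctx;
  paddle::platform::Place place = cpu_ctx.GetPlace();
  // place holds CPUPlace(), so a kernel registered under CPUPlace is chosen;
  // a CUDADeviceContext would report GPUPlace() instead.
  assert(!paddle::platform::is_gpu_place(place));
}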