diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 1e5c68a1b9d4bbd88f68292a1eedec3d0a7ad097..f30dcc000b71428bcd8cbed64bd143de086cbacd 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -99,7 +99,7 @@ class Tensor { inline Tensor& ShareDataWith(const Tensor& src); /*! Share part of the memory of the two tensors */ - inline Tensor& ShareDataWith(Tensor* src, size_t offset); + inline Tensor& ShareDataWith(const Tensor* src, size_t offset); /** * @brief Return a sub-tensor of the given tensor. @@ -181,19 +181,21 @@ class Tensor { template struct SharedPlaceholderImpl : public Placeholder { - SharedPlaceholderImpl(Place place, uint8_t* data, size_t size, + SharedPlaceholderImpl(Place place, const uint8_t* data, size_t size, std::type_index type) : ptr_(data), place_(place), size_(size), type_(type) {} virtual size_t size() const { return size_; } virtual platform::Place place() const { return place_; } - virtual void* ptr() const { return static_cast(ptr_); } + virtual void* ptr() const { + return const_cast(static_cast(ptr_)); + } virtual std::type_index type() const { return type_; } virtual void set_type(std::type_index type) { type_ = type; } virtual void set_place(platform::Place place) { place_ = place; } /*! the pointer of memory block. */ - uint8_t* ptr_; + const uint8_t* ptr_; /*! the place of memory block. */ platform::Place place_; diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index 98d53fd1e7db95ff9e4eb5caf92a14c483cca3cf..a177ef74166f202a205c6d306e84915e0f8f1129 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -162,7 +162,7 @@ inline Tensor& Tensor::ShareDataWith(const Tensor& src) { return *this; } -inline Tensor& Tensor::ShareDataWith(Tensor* src, size_t offset) { +inline Tensor& Tensor::ShareDataWith(const Tensor* src, size_t offset) { // NOTE: data size is determined by current tensor shape and data type src->check_memory_size(); PADDLE_ENFORCE_EQ(src->type(), this->type(), @@ -170,7 +170,7 @@ inline Tensor& Tensor::ShareDataWith(Tensor* src, size_t offset) { auto place = src->place(); auto type = src->type(); size_t size = src->numel() * SizeOfType(src->type()); - auto* ref = static_cast(src->mutable_data(place)) + offset; + auto* ref = src->data() + offset; if (platform::is_cpu_place(place)) { holder_.reset(new SharedPlaceholderImpl( boost::get(place), ref, size, type)); diff --git a/paddle/fluid/operators/split_byref_op.cc b/paddle/fluid/operators/split_byref_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..7413ce3e9ce60ed733bb4d27e9ec205e5f0a7e1b --- /dev/null +++ b/paddle/fluid/operators/split_byref_op.cc @@ -0,0 +1,101 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/split_byref_op.h" +#include "paddle/fluid/operators/split_op.h" + +namespace paddle { +namespace operators { +using framework::Tensor; + +class SplitByrefOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SplitOp should not be null."); + PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL, + "Outputs(Out) of SplitOp should not be empty."); + auto in_dims = ctx->GetInputDim("X"); + auto outs_names = ctx->Outputs("Out"); + size_t num = static_cast(ctx->Attrs().Get("num")); + std::vector sections = static_cast>( + ctx->Attrs().Get>("sections")); + const size_t outs_number = outs_names.size(); + std::vector outs_dims; + outs_dims.reserve(outs_number); + + if (num > 0) { + int64_t in_axis_dim = in_dims[0]; + PADDLE_ENFORCE_EQ(in_axis_dim % num, 0, + "tensor split does not result" + " in an equal division"); + size_t out_axis_dim = in_axis_dim / num; + for (size_t i = 0; i < outs_number; ++i) { + auto dim = in_dims; + dim[0] = out_axis_dim; + outs_dims.push_back(dim); + } + } else if (sections.size() > 0) { + PADDLE_ENFORCE_EQ(sections.size(), outs_number, + "tensor split sections size" + "should be equal to output size."); + for (size_t i = 0; i < outs_number; ++i) { + auto dim = in_dims; + dim[0] = sections[i]; + outs_dims.push_back(dim); + } + } + ctx->SetOutputsDim("Out", outs_dims); + } +}; + +class SplitByrefOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SplitByrefOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "(Tensor) Input tensor of the split operator."); + AddOutput("Out", "(Tensor) Output tensors of the split operator.") + .AsDuplicable(); + AddComment(R"DOC( +SplitByref operator + +Split source tensor to sevaral tensors by axis 0. No copy in this operator +is performed, output tensor shares the same blocks of memory. +)DOC"); + AddAttr>("sections", + "(vector) " + "the length of each output along the " + "specified axis.") + .SetDefault(std::vector{}); + AddAttr("num", + "(int, default 0)" + "Number of sub-tensors. This must evenly divide " + "Input.dims()[axis]") + .SetDefault(0); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +// NOTE: concat op default axis must be 0! +USE_CPU_ONLY_OP(concat); + +REGISTER_OPERATOR(split_byref, ops::SplitByrefOp, ops::SplitByrefOpMaker, + ops::SplitGradMaker); +REGISTER_OP_CPU_KERNEL( + split_byref, ops::SplitByrefOpKernel); diff --git a/paddle/fluid/operators/split_byref_op.cu.cc b/paddle/fluid/operators/split_byref_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..1faf4f55dd54a2dc28c19232643563a31850f38b --- /dev/null +++ b/paddle/fluid/operators/split_byref_op.cu.cc @@ -0,0 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/split_byref_op.h" +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + split, ops::SplitByrefOpKernel); diff --git a/paddle/fluid/operators/split_byref_op.h b/paddle/fluid/operators/split_byref_op.h new file mode 100644 index 0000000000000000000000000000000000000000..7c3ab1c1b9d9550c63b56056746c6223ce1b9c77 --- /dev/null +++ b/paddle/fluid/operators/split_byref_op.h @@ -0,0 +1,43 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class SplitByrefOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto outs = ctx.MultiOutput("Out"); + auto in_stride = framework::stride_numel(in->dims()); + auto place = ctx.GetPlace(); + + size_t input_offset = 0; + for (size_t i = 0; i < outs.size(); ++i) { + // NOTE: no need to call mutable_data here to allocate memory. + auto* out = outs[i]; + out->ShareDataWith(in, input_offset); + input_offset += out->numel() * framework::SizeOfType(out->type()); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc index e745509ec8c1f2ec305d7d4aabfdd43d847124b5..a4398df36bcc2d3b8bbe8949f27f5d6508861d95 100644 --- a/paddle/fluid/operators/split_op.cc +++ b/paddle/fluid/operators/split_op.cc @@ -108,21 +108,6 @@ Example: } }; -class SplitGradMaker : public framework::SingleGradOpDescMaker { - public: - using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; - - protected: - std::unique_ptr Apply() const override { - auto op = new framework::OpDesc(); - op->SetType("concat"); - op->SetInput("X", OutputGrad("Out")); - op->SetOutput("Out", InputGrad("X")); - op->SetAttrMap(Attrs()); - return std::unique_ptr(op); - } -}; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/split_op.h b/paddle/fluid/operators/split_op.h index e2c41f44ab3ea3c42837974dae749278c9356ba5..f0c417c70521b1bb3816f884d6ab7393473999e4 100644 --- a/paddle/fluid/operators/split_op.h +++ b/paddle/fluid/operators/split_op.h @@ -44,5 +44,20 @@ class SplitOpKernel : public framework::OpKernel { } }; +class SplitGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto op = new framework::OpDesc(); + op->SetType("concat"); + op->SetInput("X", OutputGrad("Out")); + op->SetOutput("Out", InputGrad("X")); + op->SetAttrMap(Attrs()); + return std::unique_ptr(op); + } +}; + } // namespace operators } // namespace paddle