Commit 04c559e3 authored by: T typhoonzero

wip split byref op

Parent f86d35a2
paddle/fluid/framework/tensor.h:

@@ -99,7 +99,7 @@ class Tensor {
   inline Tensor& ShareDataWith(const Tensor& src);
 
   /*! Share part of the memory of the two tensors */
-  inline Tensor& ShareDataWith(Tensor* src, size_t offset);
+  inline Tensor& ShareDataWith(const Tensor* src, size_t offset);
 
   /**
    * @brief Return a sub-tensor of the given tensor.
@@ -181,19 +181,21 @@ class Tensor {
   template <typename Place>
   struct SharedPlaceholderImpl : public Placeholder {
-    SharedPlaceholderImpl(Place place, uint8_t* data, size_t size,
+    SharedPlaceholderImpl(Place place, const uint8_t* data, size_t size,
                           std::type_index type)
         : ptr_(data), place_(place), size_(size), type_(type) {}
 
     virtual size_t size() const { return size_; }
     virtual platform::Place place() const { return place_; }
-    virtual void* ptr() const { return static_cast<void*>(ptr_); }
+    virtual void* ptr() const {
+      return const_cast<void*>(static_cast<const void*>(ptr_));
+    }
     virtual std::type_index type() const { return type_; }
     virtual void set_type(std::type_index type) { type_ = type; }
     virtual void set_place(platform::Place place) { place_ = place; }
 
     /*! the pointer of memory block. */
-    uint8_t* ptr_;
+    const uint8_t* ptr_;
 
     /*! the place of memory block. */
     platform::Place place_;
...
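A note on the double cast in ptr(): const_cast cannot change the pointee type and static_cast cannot remove const, so the conversion from const uint8_t* to void* has to go through const void* in two steps:

// const uint8_t* --static_cast--> const void* --const_cast--> void*
void* p = const_cast<void*>(static_cast<const void*>(ptr_));

Any caller that mutates through ptr() on a shared placeholder writes into the source tensor's memory, which is exactly the by-reference behavior this commit introduces.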
paddle/fluid/framework/tensor_impl.h:

@@ -162,7 +162,7 @@ inline Tensor& Tensor::ShareDataWith(const Tensor& src) {
   return *this;
 }
 
-inline Tensor& Tensor::ShareDataWith(Tensor* src, size_t offset) {
+inline Tensor& Tensor::ShareDataWith(const Tensor* src, size_t offset) {
   // NOTE: data size is determined by current tensor shape and data type
   src->check_memory_size();
   PADDLE_ENFORCE_EQ(src->type(), this->type(),
@@ -170,7 +170,7 @@ inline Tensor& Tensor::ShareDataWith(Tensor* src, size_t offset) {
   auto place = src->place();
   auto type = src->type();
   size_t size = src->numel() * SizeOfType(src->type());
-  auto* ref = static_cast<uint8_t*>(src->mutable_data(place)) + offset;
+  auto* ref = src->data<uint8_t>() + offset;
   if (platform::is_cpu_place(place)) {
     holder_.reset(new SharedPlaceholderImpl<platform::CPUPlace>(
         boost::get<platform::CPUPlace>(place), ref, size, type));
...
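For context, a conceptual sketch of how the new by-reference overload is meant to be used (not part of this commit; Tensor API names taken from this revision). After the call, out aliases a byte range of in's allocation instead of owning a copy, and out's shape and type determine how many bytes are shared:

// Share rows 2..3 of a 4x3 float tensor without copying.
framework::Tensor in;
in.mutable_data<float>(framework::make_ddim({4, 3}), platform::CPUPlace());

framework::Tensor out;
out.Resize(framework::make_ddim({2, 3}));  // shape of the shared view
// (depending on the revision, out may need a type set before sharing)
// Byte offset of row 2: 2 rows * 3 cols * sizeof(float).
out.ShareDataWith(&in, 2 * 3 * sizeof(float));
// out now points into in's memory; no allocation happened for out.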
paddle/fluid/operators/split_byref_op.cc (new file):

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/split_byref_op.h"
#include "paddle/fluid/operators/split_op.h"

namespace paddle {
namespace operators {
using framework::Tensor;

class SplitByrefOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of SplitByrefOp should not be null.");
    PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
                      "Outputs(Out) of SplitByrefOp should not be empty.");
    auto in_dims = ctx->GetInputDim("X");
    auto outs_names = ctx->Outputs("Out");
    size_t num = static_cast<size_t>(ctx->Attrs().Get<int>("num"));
    std::vector<int> sections = static_cast<std::vector<int>>(
        ctx->Attrs().Get<std::vector<int>>("sections"));
    const size_t outs_number = outs_names.size();
    std::vector<framework::DDim> outs_dims;
    outs_dims.reserve(outs_number);

    if (num > 0) {
      int64_t in_axis_dim = in_dims[0];
      PADDLE_ENFORCE_EQ(in_axis_dim % num, 0,
                        "tensor split does not result"
                        " in an equal division");
      size_t out_axis_dim = in_axis_dim / num;
      for (size_t i = 0; i < outs_number; ++i) {
        auto dim = in_dims;
        dim[0] = out_axis_dim;
        outs_dims.push_back(dim);
      }
    } else if (sections.size() > 0) {
      PADDLE_ENFORCE_EQ(sections.size(), outs_number,
                        "tensor split sections size "
                        "should be equal to output size.");
      for (size_t i = 0; i < outs_number; ++i) {
        auto dim = in_dims;
        dim[0] = sections[i];
        outs_dims.push_back(dim);
      }
    }
    ctx->SetOutputsDim("Out", outs_dims);
  }
};
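As an aside, here is the shape logic above distilled into a self-contained sketch (illustration only; InferSplitDims is not a Paddle function):

#include <cassert>
#include <cstdint>
#include <vector>

// Shape inference for an axis-0 split, mirroring SplitByrefOp::InferShape.
std::vector<std::vector<int64_t>> InferSplitDims(
    const std::vector<int64_t>& in_dims, int num,
    const std::vector<int>& sections) {
  std::vector<std::vector<int64_t>> outs;
  if (num > 0) {
    assert(in_dims[0] % num == 0);  // split must divide axis 0 evenly
    for (int i = 0; i < num; ++i) {
      auto d = in_dims;
      d[0] = in_dims[0] / num;
      outs.push_back(d);
    }
  } else {
    for (int s : sections) {  // each section is one output's axis-0 extent
      auto d = in_dims;
      d[0] = s;
      outs.push_back(d);
    }
  }
  return outs;
}

// InferSplitDims({8, 4}, 4, {})     -> four outputs of dims [2, 4]
// InferSplitDims({8, 4}, 0, {5, 3}) -> outputs of dims [5, 4] and [3, 4]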
class SplitByrefOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SplitByrefOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "(Tensor) Input tensor of the split operator.");
    AddOutput("Out", "(Tensor) Output tensors of the split operator.")
        .AsDuplicable();
    AddComment(R"DOC(
SplitByref operator

Splits the source tensor into several tensors along axis 0. This operator
performs no copy; each output tensor shares a block of the source tensor's
memory.
)DOC");
    AddAttr<std::vector<int>>("sections",
                              "(vector<int>) "
                              "the length of each output along the "
                              "specified axis.")
        .SetDefault(std::vector<int>{});
    AddAttr<int>("num",
                 "(int, default 0) "
                 "Number of sub-tensors. This must evenly divide "
                 "Input.dims()[axis].")
        .SetDefault(0);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// NOTE: concat op default axis must be 0!
USE_CPU_ONLY_OP(concat);
REGISTER_OPERATOR(split_byref, ops::SplitByrefOp, ops::SplitByrefOpMaker,
                  ops::SplitGradMaker);
REGISTER_OP_CPU_KERNEL(
    split_byref, ops::SplitByrefOpKernel<paddle::platform::CPUPlace, float>);
paddle/fluid/operators/split_byref_op.cu (new file):

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/split_byref_op.h"

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    split_byref,
    ops::SplitByrefOpKernel<paddle::platform::CUDADeviceContext, float>);
paddle/fluid/operators/split_byref_op.h (new file):

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class SplitByrefOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
    auto in_stride = framework::stride_numel(in->dims());
    auto place = ctx.GetPlace();

    size_t input_offset = 0;
    for (size_t i = 0; i < outs.size(); ++i) {
      // NOTE: no need to call mutable_data here to allocate memory.
      auto* out = outs[i];
      out->ShareDataWith(in, input_offset);
      input_offset += out->numel() * framework::SizeOfType(out->type());
    }
  }
};

}  // namespace operators
}  // namespace paddle
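The offset bookkeeping above works because the split is along axis 0 and rows are contiguous, so output i begins exactly where output i-1 ends. A standalone sketch of that arithmetic (illustrative values, not Paddle code):

#include <cstddef>
#include <vector>

// Running byte offsets for a float input of dims [8, 4] split into {5, 3} rows.
int main() {
  const std::vector<int> sections = {5, 3};
  const std::size_t row_bytes = 4 * sizeof(float);  // 4 columns of float
  std::size_t input_offset = 0;
  for (int s : sections) {
    // This output aliases bytes [input_offset, input_offset + s * row_bytes).
    input_offset += static_cast<std::size_t>(s) * row_bytes;
  }
  // The outputs tile the input exactly: 5 rows + 3 rows == 8 rows.
  return input_offset == 8 * row_bytes ? 0 : 1;
}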
paddle/fluid/operators/split_op.cc:

@@ -108,21 +108,6 @@ Example:
   }
 };
 
-class SplitGradMaker : public framework::SingleGradOpDescMaker {
- public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
-
- protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto op = new framework::OpDesc();
-    op->SetType("concat");
-    op->SetInput("X", OutputGrad("Out"));
-    op->SetOutput("Out", InputGrad("X"));
-    op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDesc>(op);
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
...
paddle/fluid/operators/split_op.h:

@@ -44,5 +44,20 @@ class SplitOpKernel : public framework::OpKernel<T> {
   }
 };
 
+class SplitGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto op = new framework::OpDesc();
+    op->SetType("concat");
+    op->SetInput("X", OutputGrad("Out"));
+    op->SetOutput("Out", InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return std::unique_ptr<framework::OpDesc>(op);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
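Moving SplitGradMaker out of split_op.cc into split_op.h lets split_byref_op.cc, which includes split_op.h, reuse it in its REGISTER_OPERATOR call. The maker expresses the backward of an axis-0 split as an axis-0 concat of the output gradients, which is why the new .cc pulls in the concat op and notes that its default axis must be 0. A sketch of the generated op descs (gradient names follow fluid's @GRAD convention; two outputs assumed for illustration):

// forward:   split_byref(X)                       -> Out@0, Out@1
// backward:  concat(X = {Out@0@GRAD, Out@1@GRAD}) -> X@GRAD   (axis 0)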