Commit 0d751917 authored by chengduoZH

Speed up lod_tensor_to_array and array_to_lod_tensor

Parent 437debf4
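
In outline: both operators previously issued one framework::TensorCopy per sequence slice inside their main loops, i.e. one device copy (and on GPU one kernel launch) per slice. This commit records the slices instead and performs a single batched math::ConcatFunctor call (array_to_lod_tensor) or math::ConcatGradFunctor call (lod_tensor_to_array), dispatched once on the place and once on the data type. The core of the change, condensed from the array_to_lod_tensor hunks below:

    // Before: one device copy per slice, inside the loop.
    auto slice = out->Slice(out_offset, out_offset + len);
    framework::TensorCopy(x[x_idx].Slice(start_offset, end_offset), place,
                          dev_ctx, &slice);
    out_offset += len;

    // After: gather all slices in the loop, then concat once along axis 0.
    functor.in.emplace_back(x[x_idx].Slice(start_offset, end_offset));
    // ... after the loop:
    functor.out = out;
    platform::VisitPlace(place, functor);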
......@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/concat.h"
#include <numeric>
#include "paddle/fluid/framework/lod_rank_table.h"
......@@ -24,6 +25,50 @@ namespace operators {
using LoD = framework::LoD;
struct ArrayToLoDFunctor;
template <typename DeviceContext>
struct ArrayToLoDFunctorImpl {
const ArrayToLoDFunctor *prev_functor_;
DeviceContext *dev_ctx_;
template <typename T>
void apply();
};
struct ArrayToLoDFunctor : public boost::static_visitor<void> {
std::vector<framework::Tensor> in;
mutable framework::Tensor *out;
template <typename Place>
void operator()(Place place) const {
auto &pool = platform::DeviceContextPool::Instance();
if (std::is_same<Place, platform::CPUPlace>::value) {
Apply(static_cast<platform::CPUDeviceContext *>(pool.Get(place)));
} else {
#ifdef PADDLE_WITH_CUDA
Apply(static_cast<platform::CUDADeviceContext *>(pool.Get(place)));
#else
PADDLE_THROW("Fluid is not compiled with CUDA");
#endif
}
}
template <typename DeviceContext>
void Apply(DeviceContext *dev_ctx) const {
ArrayToLoDFunctorImpl<DeviceContext> functor;
functor.dev_ctx_ = dev_ctx;
functor.prev_functor_ = this;
framework::VisitDataType(framework::ToDataType(out->type()), functor);
}
};
template <typename DeviceContext>
template <typename T>
void ArrayToLoDFunctorImpl<DeviceContext>::apply() {
math::ConcatFunctor<DeviceContext, T> func;
func(*dev_ctx_, prev_functor_->in, 0, prev_functor_->out);
}
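
ArrayToLoDFunctor and ArrayToLoDFunctorImpl together perform a two-level runtime dispatch: platform::VisitPlace (boost::apply_visitor over the place variant) picks the concrete DeviceContext, and framework::VisitDataType picks the element type T, so the call bottoms out in the statically typed math::ConcatFunctor<DeviceContext, T>. The chain, as an outline sketch rather than runnable code:

    // platform::VisitPlace(place, functor)
    //   -> ArrayToLoDFunctor::operator()(place)   // resolves Place
    //   -> Apply(dev_ctx)                         // resolves DeviceContext
    //   -> framework::VisitDataType(..., impl)    // resolves element type T
    //   -> ArrayToLoDFunctorImpl<Ctx>::apply<T>()
    //   -> math::ConcatFunctor<Ctx, T>()(dev_ctx, in, /*axis=*/0, out)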
class ArrayToLoDTensorOp : public framework::OperatorBase {
public:
ArrayToLoDTensorOp(const std::string &type,
......@@ -47,14 +92,18 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
int rank = x[0].dims().size();
platform::Place place = x[0].place();
std::type_index data_type = x[0].type();
framework::DDim ins_dims = framework::slice_ddim(x[0].dims(), 1, rank);
int64_t batch_size = x[0].dims()[0];
framework::DDim ins_dims = rank > 1
? framework::slice_ddim(x[0].dims(), 1, rank)
: framework::make_ddim({0});
for (size_t i = 1; i < x.size(); ++i) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x[i].dims(), 1, rank), ins_dims,
auto ins_i_dims = rank > 1 ? framework::slice_ddim(x[i].dims(), 1, rank)
: framework::make_ddim({0});
PADDLE_ENFORCE_EQ(ins_i_dims, ins_dims,
"The dimension of the %zu'th element in LoDTensorArray "
"differs from previous ones.",
i);
PADDLE_ENFORCE(platform::places_are_same_class(x[i].place(), place),
PADDLE_ENFORCE(x[i].place() == place,
"The place class of the %zu'th element in LoDTensorArray "
"differs from previous ones.",
i);
......@@ -82,13 +131,14 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
// Build LoDTensor `out`
framework::LoD *out_lod = out->mutable_lod();
out_lod->clear();
size_t out_offset = 0;
auto prefix_lod = rank_table.coarse_lod();
prefix_lod.emplace_back();
auto &cur_level_lod = prefix_lod.back();
cur_level_lod.push_back(0);
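// Gather all source slices here; the single batched concat after the loop
// replaces the per-slice TensorCopy that previously ran in its body.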
ArrayToLoDFunctor functor;
for (size_t idx : table_item_idx) {
cur_level_lod.push_back(cur_level_lod.back() + table_items[idx].length);
PADDLE_ENFORCE_LE(table_items[idx].length, x.size());
for (size_t x_idx = 0; x_idx < table_items[idx].length; ++x_idx) {
auto lod_and_offset = framework::GetSubLoDAndAbsoluteOffset(
x[x_idx].lod(), idx, idx + 1, 0);
......@@ -106,17 +156,11 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
if (len == 0) {
continue;
}
auto slice = out->Slice(out_offset, out_offset + len);
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::TensorCopy(x[x_idx].Slice(start_offset, end_offset), place,
dev_ctx, &slice);
out_offset += len;
functor.in.emplace_back(x[x_idx].Slice(start_offset, end_offset));
}
}
functor.out = out;
platform::VisitPlace(place, functor);
out_lod->insert(out_lod->begin(), prefix_lod.begin(), prefix_lod.end());
}
};
......
......@@ -95,6 +95,7 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X"));
}
};
......
......@@ -109,8 +109,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
concat_grad_functor;
concat_grad_functor(dev_ctx, *out_grad, ins, static_cast<int>(axis),
&outputs);
concat_grad_functor(dev_ctx, *out_grad,
ctx.MultiInput<framework::Tensor>("X"),
static_cast<int>(axis), &outputs);
}
}
};
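
The ref_inputs argument here now comes straight from ctx.MultiInput<framework::Tensor>("X"): ConcatGradFunctor's signature is loosened from LoDTensor pointers to plain Tensor pointers (see the math/concat.h hunk at the end of this diff), which is what lets lod_tensor_to_array reuse the same functor as a batched scatter. The generalized signature, copied from the header below:

    void operator()(const DeviceContext& context, const framework::Tensor& input,
                    const std::vector<const framework::Tensor*>& ref_inputs,
                    int axis, std::vector<framework::Tensor*>* outputs);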
......
......@@ -11,10 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <map>
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/port.h"
......@@ -26,6 +29,61 @@ struct CopyRange {
size_t end;
};
struct LoDTensorToArrayFunctor;
template <typename DeviceContext>
struct LoDTensorToArrayFunctorImpl {
const LoDTensorToArrayFunctor *prev_functor_;
DeviceContext *dev_ctx_;
template <typename T>
void apply();
};
struct LoDTensorToArrayFunctor : public boost::static_visitor<void> {
std::vector<const framework::Tensor *> ref_inputs_;
mutable std::vector<framework::Tensor *> outputs_;
const framework::Tensor &input_;
explicit LoDTensorToArrayFunctor(const framework::Tensor &input)
: input_(input) {}
void AddOutput(framework::Tensor *t) {
outputs_.emplace_back(t);
ref_inputs_.emplace_back(t);
}
template <typename Place>
void operator()(Place place) const {
auto &pool = platform::DeviceContextPool::Instance();
auto *dev_ctx = pool.Get(place);
if (std::is_same<Place, platform::CPUPlace>::value) {
Apply(static_cast<platform::CPUDeviceContext *>(dev_ctx));
} else {
#ifdef PADDLE_WITH_CUDA
Apply(static_cast<platform::CUDADeviceContext *>(dev_ctx));
#else
PADDLE_THROW("Not compiled with cuda");
#endif
}
}
template <typename DeviceContext>
void Apply(DeviceContext *dev_ctx) const {
LoDTensorToArrayFunctorImpl<DeviceContext> func;
func.prev_functor_ = this;
func.dev_ctx_ = dev_ctx;
framework::VisitDataType(framework::ToDataType(input_.type()), func);
}
};
template <typename DeviceContext>
template <typename T>
void LoDTensorToArrayFunctorImpl<DeviceContext>::apply() {
math::ConcatGradFunctor<DeviceContext, T> func;
func(*dev_ctx_, prev_functor_->input_, prev_functor_->ref_inputs_, 0,
&prev_functor_->outputs_);
}
class LoDTensorToArrayOp : public framework::OperatorBase {
public:
LoDTensorToArrayOp(const std::string &type,
......@@ -72,6 +130,11 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
copy_ranges[t].emplace_back(CopyRange{start_offset, end_offset});
}
}
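// Stash every destination slice in a scope-owned map keyed by its absolute
// begin offset in x; the batched split below then fills all of them at once.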
auto &outputs = *const_cast<framework::Scope &>(scope)
.Var()
->GetMutable<std::map<size_t, framework::Tensor>>();
for (size_t i = 0; i < max_seq_len; ++i) {
auto &ranges = copy_ranges[i];
size_t height = std::accumulate(
......@@ -90,17 +153,16 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
// out[i][offset: offset+len] = x[each_range.begin: each_range.end]
auto slice = out[i].Slice(static_cast<int>(offset),
static_cast<int>(offset + len));
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::TensorCopy(x.Slice(static_cast<int>(each_range.begin),
static_cast<int>(each_range.end)),
x.place(), dev_ctx, &slice);
outputs.insert({each_range.begin, slice});
offset += len;
}
}
LoDTensorToArrayFunctor functor(x);
for (auto &out_pair : outputs) {
functor.AddOutput(&out_pair.second);
}
platform::VisitPlace(place, functor);
}
};
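
Registering the slices in a std::map<size_t, framework::Tensor> keyed by each slice's begin offset in x is what makes the single batched split valid: iterating the map visits slices in ascending source-offset order, which is the order in which ConcatGradFunctor consumes x's rows. Note also that AddOutput pushes each slice into both ref_inputs_ (where ConcatGradFunctor reads the per-output row counts) and outputs_ (where it writes). A self-contained illustration of the ordering guarantee this relies on:

    #include <cstdio>
    #include <map>

    int main() {
      // Keys mimic absolute begin offsets into x; insertion order is scrambled.
      std::map<size_t, const char *> outputs{
          {40, "slice@40"}, {0, "slice@0"}, {25, "slice@25"}};
      for (const auto &kv : outputs)  // iterates 0, 25, 40: ascending offsets
        std::printf("%zu -> %s\n", kv.first, kv.second);
      return 0;
    }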
......
......@@ -27,7 +27,7 @@ template <typename T>
class ConcatFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis,
const std::vector<framework::Tensor>& input, int axis,
framework::Tensor* output) {
// TODO(zcd): Add input data validity checking
int num = input.size();
......@@ -71,7 +71,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input,
const std::vector<const framework::LoDTensor*>& ref_inputs,
const std::vector<const framework::Tensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs) {
// TODO(zcd): Add input data validity checking
size_t num = outputs->size();
......@@ -109,16 +109,11 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
}
}
};
#define DEFINE_FUNCTOR(type) \
template class ConcatFunctor<platform::CPUDeviceContext, type>; \
template class ConcatGradFunctor<platform::CPUDeviceContext, type>;
template class ConcatFunctor<platform::CPUDeviceContext, int>;
template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
template class ConcatFunctor<platform::CPUDeviceContext, float>;
template class ConcatFunctor<platform::CPUDeviceContext, double>;
template class ConcatGradFunctor<platform::CPUDeviceContext, int>;
template class ConcatGradFunctor<platform::CPUDeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CPUDeviceContext, float>;
template class ConcatGradFunctor<platform::CPUDeviceContext, double>;
FOR_ALL_TYPES(DEFINE_FUNCTOR);
} // namespace math
} // namespace operators
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
......@@ -118,7 +119,7 @@ template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis,
const std::vector<framework::Tensor>& input, int axis,
framework::Tensor* output) {
// TODO(zcd): Add input data validity checking
int in_num = input.size();
......@@ -192,8 +193,8 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input,
const std::vector<const framework::LoDTensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs) {
const std::vector<const framework::Tensor*>& ref_inputs,
int axis, std::vector<framework::Tensor*>* outputs) {
// TODO(zcd): Add input data validity checking
int o_num = outputs->size();
int out_row = 1;
......@@ -261,15 +262,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
}
};
template class ConcatFunctor<platform::CUDADeviceContext, int>;
template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatFunctor<platform::CUDADeviceContext, float>;
template class ConcatFunctor<platform::CUDADeviceContext, double>;
#define DEFINE_FUNCTOR(type) \
template class ConcatFunctor<platform::CUDADeviceContext, type>; \
template class ConcatGradFunctor<platform::CUDADeviceContext, type>
template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
template class ConcatGradFunctor<platform::CUDADeviceContext, double>;
FOR_ALL_TYPES(DEFINE_FUNCTOR);
} // namespace math
} // namespace operators
......
......@@ -37,7 +37,7 @@ template <typename DeviceContext, typename T>
class ConcatFunctor {
public:
void operator()(const DeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis,
const std::vector<framework::Tensor>& input, int axis,
framework::Tensor* output);
};
......@@ -57,10 +57,21 @@ template <typename DeviceContext, typename T>
class ConcatGradFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
const std::vector<const framework::LoDTensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs);
const std::vector<const framework::Tensor*>& ref_inputs,
int axis, std::vector<framework::Tensor*>* outputs);
};
} // namespace math
} // namespace operators
} // namespace paddle
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(::paddle::platform::float16)
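
FOR_ALL_TYPES(DEFINE_FUNCTOR) replaces the hand-written instantiation lists in both the CPU and CUDA translation units, and extends coverage to bool, the narrow integer types, and float16. For example, the float entry of the CPU macro expands to:

    template class ConcatFunctor<platform::CPUDeviceContext, float>;
    template class ConcatGradFunctor<platform::CPUDeviceContext, float>;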