提交 e7940141 编写于 作者: C chengduoZH

refine seq_concat

上级 437debf4
...@@ -441,7 +441,10 @@ static void InitInferShapeFuncs() { ...@@ -441,7 +441,10 @@ static void InitInferShapeFuncs() {
for (auto &kern_pair : OperatorWithKernel::AllOpKernels()) { for (auto &kern_pair : OperatorWithKernel::AllOpKernels()) {
auto op_type = kern_pair.first; auto op_type = kern_pair.first;
auto &op_info = info_map.at(op_type); auto it = info_map.find(op_type);
PADDLE_ENFORCE(it != info_map.end(), "%s has not been registered",
op_type);
auto &op_info = it->second;
auto op = static_cast<OperatorWithKernel *>(op_info.Creator()( auto op = static_cast<OperatorWithKernel *>(op_info.Creator()(
"", VariableNameMap{}, VariableNameMap{}, AttributeMap{})); "", VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
if (op_info.infer_shape_) { // infer_shape has been registered. if (op_info.infer_shape_) { // infer_shape has been registered.
......
...@@ -95,6 +95,7 @@ class ConcatOpGrad : public framework::OperatorWithKernel { ...@@ -95,6 +95,7 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X")); ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X"));
} }
}; };
......
...@@ -109,8 +109,9 @@ class ConcatGradKernel : public framework::OpKernel<T> { ...@@ -109,8 +109,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
paddle::operators::math::ConcatGradFunctor<DeviceContext, T> paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
concat_grad_functor; concat_grad_functor;
concat_grad_functor(dev_ctx, *out_grad, ins, static_cast<int>(axis), concat_grad_functor(dev_ctx, *out_grad,
&outputs); ctx.MultiInput<framework::Tensor>("X"),
static_cast<int>(axis), &outputs);
} }
} }
}; };
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <vector>
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -24,10 +24,22 @@ namespace detail { ...@@ -24,10 +24,22 @@ namespace detail {
* and passed by `args` * and passed by `args`
*/ */
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
inline T &Ref(T *ptr, ARGS &&... args) { inline T& Ref(T* ptr, ARGS&&... args) {
PADDLE_ENFORCE(ptr != nullptr, args...); PADDLE_ENFORCE(ptr != nullptr, args...);
return *ptr; return *ptr;
} }
template <typename T, typename... ARGS>
inline std::vector<std::reference_wrapper<T>> VectorRef(
const std::vector<T*>& vec, ARGS&&... args) {
std::vector<std::reference_wrapper<T>> result;
result.reserve(vec.size());
for (auto* ptr : vec) {
result.emplace_back(Ref(ptr, args...));
}
return result;
}
} // namespace detail } // namespace detail
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -27,7 +27,7 @@ template <typename T> ...@@ -27,7 +27,7 @@ template <typename T>
class ConcatFunctor<platform::CPUDeviceContext, T> { class ConcatFunctor<platform::CPUDeviceContext, T> {
public: public:
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis, const std::vector<framework::Tensor>& input, int axis,
framework::Tensor* output) { framework::Tensor* output) {
// TODO(zcd): Add input data validity checking // TODO(zcd): Add input data validity checking
int num = input.size(); int num = input.size();
...@@ -71,7 +71,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> { ...@@ -71,7 +71,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
public: public:
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
const std::vector<const framework::LoDTensor*>& ref_inputs, const std::vector<const framework::Tensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs) { const int axis, std::vector<framework::Tensor*>* outputs) {
// TODO(zcd): Add input data validity checking // TODO(zcd): Add input data validity checking
size_t num = outputs->size(); size_t num = outputs->size();
...@@ -109,16 +109,11 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> { ...@@ -109,16 +109,11 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
} }
} }
}; };
#define DEFINE_FUNCTOR(type) \
template class ConcatFunctor<platform::CPUDeviceContext, type>; \
template class ConcatGradFunctor<platform::CPUDeviceContext, type>;
template class ConcatFunctor<platform::CPUDeviceContext, int>; FOR_ALL_TYPES(DEFINE_FUNCTOR);
template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
template class ConcatFunctor<platform::CPUDeviceContext, float>;
template class ConcatFunctor<platform::CPUDeviceContext, double>;
template class ConcatGradFunctor<platform::CPUDeviceContext, int>;
template class ConcatGradFunctor<platform::CPUDeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CPUDeviceContext, float>;
template class ConcatGradFunctor<platform::CPUDeviceContext, double>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat.h" #include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -118,7 +119,7 @@ template <typename T> ...@@ -118,7 +119,7 @@ template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> { class ConcatFunctor<platform::CUDADeviceContext, T> {
public: public:
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis, const std::vector<framework::Tensor>& input, int axis,
framework::Tensor* output) { framework::Tensor* output) {
// TODO(zcd): Add input data validity checking // TODO(zcd): Add input data validity checking
int in_num = input.size(); int in_num = input.size();
...@@ -192,8 +193,8 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> { ...@@ -192,8 +193,8 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
public: public:
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
const std::vector<const framework::LoDTensor*>& ref_inputs, const std::vector<const framework::Tensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs) { int axis, std::vector<framework::Tensor*>* outputs) {
// TODO(zcd): Add input data validity checking // TODO(zcd): Add input data validity checking
int o_num = outputs->size(); int o_num = outputs->size();
int out_row = 1; int out_row = 1;
...@@ -261,15 +262,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> { ...@@ -261,15 +262,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
} }
}; };
template class ConcatFunctor<platform::CUDADeviceContext, int>; #define DEFINE_FUNCTOR(type) \
template class ConcatFunctor<platform::CUDADeviceContext, int64_t>; template class ConcatFunctor<platform::CUDADeviceContext, type>; \
template class ConcatFunctor<platform::CUDADeviceContext, float>; template class ConcatGradFunctor<platform::CUDADeviceContext, type>
template class ConcatFunctor<platform::CUDADeviceContext, double>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int>; FOR_ALL_TYPES(DEFINE_FUNCTOR);
template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
template class ConcatGradFunctor<platform::CUDADeviceContext, double>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -37,7 +37,7 @@ template <typename DeviceContext, typename T> ...@@ -37,7 +37,7 @@ template <typename DeviceContext, typename T>
class ConcatFunctor { class ConcatFunctor {
public: public:
void operator()(const DeviceContext& context, void operator()(const DeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis, const std::vector<framework::Tensor>& input, int axis,
framework::Tensor* output); framework::Tensor* output);
}; };
...@@ -57,10 +57,21 @@ template <typename DeviceContext, typename T> ...@@ -57,10 +57,21 @@ template <typename DeviceContext, typename T>
class ConcatGradFunctor { class ConcatGradFunctor {
public: public:
void operator()(const DeviceContext& context, const framework::Tensor& input, void operator()(const DeviceContext& context, const framework::Tensor& input,
const std::vector<const framework::LoDTensor*>& ref_inputs, const std::vector<const framework::Tensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs); int axis, std::vector<framework::Tensor*>* outputs);
}; };
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(::paddle::platform::float16)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
You may obtain a copy of the License at // You may obtain a copy of the License at
//
http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
//
Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
limitations under the License. */ // limitations under the License.
#include "paddle/fluid/operators/sequence_concat_op.h" #include "paddle/fluid/operators/sequence_concat_op.h"
#include <vector>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class SequenceConcatOp : public framework::OperatorWithKernel { class SeqConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; void Make() override {
AddInput("X", "The inputs of sequence concat op").AsDuplicable();
void InferShape(framework::InferShapeContext* ctx) const override { AddOutput("Out", "The output of sequence concat op");
PADDLE_ENFORCE(ctx->HasInputs("X"), AddComment(
"Inputs(X) of SequenceConcatOp should not be null."); "Sequence Concat Op\n"
PADDLE_ENFORCE(ctx->HasOutput("Out"), "It will concat LoD tensors by its sequence information.\n"
"Output(Out) of SequenceConcatOp should not be null."); "For example:\n"
const size_t level = static_cast<size_t>(ctx->Attrs().Get<int>("level")); " LoD of X1 = [0, 3, 7]\n"
const size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis")); " LoD of X2 = [0, 7, 9]\n"
PADDLE_ENFORCE(level == 0UL || level == 1UL, " Result LoD is [0, (3+7), (7+9)]\n"
"The sequence_concat operator only accepts sequence " " i.e.[0, 10, 16]\n");
"or a nested sequence as its input.");
auto ins_dims = ctx->GetInputsDim("X");
framework::DDim out_dims = ins_dims[0];
const size_t n = ins_dims.size();
for (size_t i = 1; i < n; ++i) {
out_dims[axis] += ins_dims[i][axis];
}
ctx->SetOutputDim("Out", out_dims);
} }
}; };
class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker { class SeqConcatShapeInferer : public framework::InferShapeBase {
public: public:
void Make() override { void operator()(framework::InferShapeContext *context) const override {
AddInput("X", try {
"(LodTensorArray) Input is a vector of LoDTensor, " PADDLE_ENFORCE(context->HasInputs("X"));
"each of which is a variable-length sequence or nested sequence.") PADDLE_ENFORCE(context->HasOutput("Out"));
.AsDuplicable();
AddOutput("Out", auto x_dims = context->GetInputsDim("X");
"(LoDTensor), Variable-length output of " int64_t batch_size = 0;
"sequence_concat Op."); int64_t feature_size = 0;
AddAttr<int>("axis", std::vector<int64_t> out_dims;
"(int, default 0) " for (auto &x_dim : x_dims) {
"The axis along which the inputs will be joined. " if (out_dims.empty()) {
"If axis is 0, the inputs will be joined with LoD index.") out_dims = framework::vectorize(x_dim);
.SetDefault(0); }
AddAttr<int>("level", batch_size += x_dim[0];
"(int, default 0) " if (feature_size == 0) {
"The level at which the inputs will be joined. " feature_size = framework::product(x_dim) / x_dim[0];
"If the level is 0, the inputs will be joined at the nested " } else {
"sequence level. " PADDLE_ENFORCE_EQ(
"If the level is 1, the inputs will be joined at the " feature_size, framework::product(x_dim) / x_dim[0],
"sequence level. " "Inputs of sequence concat must have same feature size");
"The level should be less than the level number of inputs.") }
.SetDefault(0); }
AddComment(R"DOC( if (batch_size < 0) {
The sequence_concat operator concatenates multiple LoDTensors. batch_size = -1; // Normalize batch size for compile time.
It only supports sequence (LoD Tensor with level number is 1) }
or a nested sequence (LoD tensor with level number is 2) as its input. out_dims[0] = batch_size;
- Case1: context->SetOutputDim("Out", framework::make_ddim(out_dims));
If the axis is other than 0(here, axis is 1 and level is 1), if (!context->IsRuntime()) { // Runtime LoD infershape will be computed
each input should have the same LoD information and the LoD // in Kernel.
information of the output keeps the same as the input. context->ShareLoD("X", "Out");
}
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4) } catch (...) {
LoD(x1) = {{0,2,4}, {0,1,2,3,4}}; Dims(x1) = (4,4,4) PADDLE_THROW("Unknown error");
LoD(Out) = {{0,2,4}, {0,1,2,3,4}}; Dims(Out) = (4,7,4) }
- Case2:
If the axis is 0(here, leve is 0), the inputs are concatenated along
time steps, the LoD information of the output need to re-compute.
The LoD information of level-1 should be same.
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
LoD(x1) = {{0,2,4}, {0,1,3,5,7}}; Dims(x1) = (7,3,4)
LoD(Out) = {{0,2,4}, {0,2,5,8,11}}; Dims(Out) = (11,3,4)
- Case3:
If the axis is 0(here, level is 1).
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
LoD(x1) = {{0,3,4}, {0,1,3,5,7}}; Dims(x1) = (7,3,4)
LoD(Out) = {{0,5,8}, {0,1,2,3,5,7,8,9,11}}; Dims(Out) = (11,3,4)
- Case4:
If the LoD number is 1, axis is 0, level is 0
LoD(x0) = {{0,1,2,3,4}}; Dims(x0) = (4,3,4)
LoD(x1) = {{0,1,3,5,7}}; Dims(x1) = (7,3,4)
LoD(Out) = {{0,2,5,8,11}}; Dims(Out) = (11,3,4)
NOTE: The levels of all the inputs should be the same.
)DOC");
} }
}; };
class SequenceConcatGradOp : public framework::OperatorWithKernel { class SeqConcatGradShapeInferer : public framework::InferShapeBase {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; void operator()(framework::InferShapeContext *context) const override {
context->SetOutputsDim(framework::GradVarName("X"),
void InferShape(framework::InferShapeContext* ctx) const override { context->GetInputsDim("X"));
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"The gradient of Out should not be null.");
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
"The gradient of X should not be null.");
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace op = paddle::operators;
REGISTER_OPERATOR(sequence_concat, ops::SequenceConcatOp,
ops::SequenceConcatOpMaker, REGISTER_OPERATOR(sequence_concat, paddle::framework::OperatorWithKernel,
paddle::framework::DefaultGradOpDescMaker< op::SeqConcatOpMaker, op::SeqConcatShapeInferer,
false> /* set false to disable empty grad */); paddle::framework::DefaultGradOpDescMaker<false>);
REGISTER_OPERATOR(sequence_concat_grad, ops::SequenceConcatGradOp); template <typename T>
REGISTER_OP_CPU_KERNEL( using Kernel = op::SeqConcatKernel<paddle::platform::CPUDeviceContext, T>;
sequence_concat, REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>);
ops::SequenceConcatOpKernel<paddle::platform::CPUDeviceContext, float>); REGISTER_OPERATOR(sequence_concat_grad, paddle::framework::OperatorWithKernel,
REGISTER_OP_CPU_KERNEL( op::SeqConcatGradShapeInferer);
sequence_concat_grad, template <typename T>
ops::SequenceConcatGradOpKernel<paddle::platform::CPUDeviceContext, float>); using GradKernel =
op::SeqConcatGradKernel<paddle::platform::CPUDeviceContext, T>;
REGISTER_OP_CPU_KERNEL(sequence_concat_grad, GradKernel<float>,
GradKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
You may obtain a copy of the License at // You may obtain a copy of the License at
//
http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
//
Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
limitations under the License. */ // limitations under the License.
#include "paddle/fluid/operators/sequence_concat_op.h" #include "paddle/fluid/operators/sequence_concat_op.h"
namespace ops = paddle::operators; template <typename T>
REGISTER_OP_CUDA_KERNEL( using Kernel =
sequence_concat, paddle::operators::SeqConcatKernel<paddle::platform::CUDADeviceContext, T>;
ops::SequenceConcatOpKernel<paddle::platform::CUDADeviceContext, float>); REGISTER_OP_CUDA_KERNEL(sequence_concat, Kernel<float>, Kernel<double>);
REGISTER_OP_CUDA_KERNEL(sequence_concat_grad, template <typename T>
ops::SequenceConcatGradOpKernel< using GradKernel =
paddle::platform::CUDADeviceContext, float>); paddle::operators::SeqConcatGradKernel<paddle::platform::CUDADeviceContext,
T>;
REGISTER_OP_CUDA_KERNEL(sequence_concat_grad, GradKernel<float>,
GradKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
You may obtain a copy of the License at // You may obtain a copy of the License at
//
http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
//
Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
limitations under the License. */ // limitations under the License.
#pragma once #pragma once
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/concat.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; namespace detail {
using LoDTensor = framework::LoDTensor; template <typename Container>
using LoD = framework::LoD; inline framework::LoD ConcatLoD(const Container &xs,
std::vector<framework::Tensor> *xs_in_order) {
template <typename T> std::vector<size_t> result;
LoD ConcatLoD(const std::vector<const T*> ins, const size_t level) { result.resize(xs[0].get().lod()[0].size());
auto out_lod = ins[0]->lod();
auto numLevels = ins[0]->NumLevels(); for (size_t i = 1; i < result.size(); ++i) {
const size_t n = ins.size(); size_t sum = 0;
const size_t level_idx = ins[0]->NumLevels() - 1 - level; for (size_t j = 0; j < xs.size(); ++j) {
for (size_t i = 1; i < n; ++i) { auto &x_lod = xs[j].get().lod()[0];
for (size_t j = 0; j < ins[i]->lod()[level_idx].size(); ++j) { const framework::Tensor &tensor = xs[j].get();
out_lod[level_idx][j] += ins[i]->lod()[level_idx][j]; xs_in_order->emplace_back(tensor.Slice(x_lod[i - 1], x_lod[i]));
sum += x_lod[i];
} }
result[i] = sum;
} }
framework::LoD lod;
for (size_t i = level_idx; i < numLevels - 1; ++i) { lod.emplace_back(result);
size_t lod_len = 1; return lod;
for (size_t j = 0; j < n; ++j) {
lod_len += ins[j]->lod()[i + 1].size() - 1;
}
out_lod[i + 1].clear();
out_lod[i + 1].resize(lod_len);
size_t idx = 1;
for (size_t j = 0; j < ins[0]->lod()[i].size() - 1; ++j) {
for (size_t k = 0; k < n; ++k) {
for (size_t m = ins[k]->lod()[i][j]; m < ins[k]->lod()[i][j + 1]; ++m) {
out_lod[i + 1][idx] = out_lod[i + 1][idx - 1] +
ins[k]->lod()[i + 1][m + 1] -
ins[k]->lod()[i + 1][m];
idx++;
}
}
}
}
return out_lod;
} }
} // namespace detail
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SequenceConcatOpKernel : public framework::OpKernel<T> { class SeqConcatKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext &context) const override {
auto ins = ctx.MultiInput<LoDTensor>("X"); auto xs = detail::VectorRef(context.MultiInput<framework::LoDTensor>("X"),
auto* out = ctx.Output<LoDTensor>("Out"); "Cannot find multiple input X");
const size_t axis = static_cast<size_t>(ctx.Attr<int>("axis")); auto &out = detail::Ref(context.Output<framework::LoDTensor>("Out"),
const size_t level = static_cast<size_t>(ctx.Attr<int>("level")); "Cannot find output");
const size_t n = ins.size();
size_t lod_size = 0;
for (size_t i = 1; i < n; ++i) { for (auto &x : xs) {
PADDLE_ENFORCE_EQ(ins[0]->NumLevels(), ins[i]->NumLevels(), if (lod_size == 0) {
"The levels of all the input LoDTensors " lod_size = x.get().lod()[0].size();
"should be the same."); } else {
PADDLE_ENFORCE_EQ(ins[0]->dims().size(), ins[i]->dims().size(), PADDLE_ENFORCE_EQ(
"The dimension size of all the input LoDTensors " lod_size, x.get().lod()[0].size(),
"should be the same."); "The number of sequence must be same between each input");
const size_t dims_size = ins[i]->dims().size();
for (size_t j = 0; j < dims_size; ++j) {
if (j == axis) continue;
PADDLE_ENFORCE_EQ(ins[0]->dims()[j], ins[i]->dims()[j],
"Except for the dimension of the specified "
"axis along which all the inputs are concatenated, "
"dimensions of all the other axises of the input "
"LoDTensors should be the same.");
}
}
PADDLE_ENFORCE_GT(ins[0]->NumLevels(), level,
"The levels of all the input LoDTensors "
"should be greater than the specify level");
out->mutable_data<T>(ctx.GetPlace());
auto out_lod = ins[0]->lod();
if (axis == 0) {
out_lod = ConcatLoD<LoDTensor>(ins, level);
}
out->set_lod(out_lod);
const size_t level_idx = out_lod.size() - level - 1;
auto out_lod_level = framework::ToAbsOffset(out_lod)[level_idx];
for (size_t i = 0; i < out_lod_level.size() - 1; ++i) {
Tensor out_t = out->Slice(static_cast<int>(out_lod_level[i]),
static_cast<int>(out_lod_level[i + 1]));
auto out_stride = framework::stride(out_t.dims());
size_t offset = 0;
for (size_t j = 0; j < n; ++j) {
auto in_lod_level = framework::ToAbsOffset(ins[j]->lod())[level_idx];
auto in_stride = framework::stride(ins[j]->dims());
Tensor in_t = ins[j]->Slice(static_cast<int>(in_lod_level[i]),
static_cast<int>(in_lod_level[i + 1]));
size_t axis_dim = in_t.dims()[axis];
StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(), in_stride,
in_t.dims(), out_stride, out_t.data<T>() + offset);
offset += axis_dim * in_stride[axis];
} }
} }
PADDLE_ENFORCE_NE(lod_size, 0, "Each input must have sequence information");
std::vector<framework::Tensor> x_in_order;
out.set_lod(detail::ConcatLoD(xs, &x_in_order));
out.mutable_data<T>(context.GetPlace());
math::ConcatFunctor<DeviceContext, T> functor;
functor(context.template device_context<DeviceContext>(), x_in_order, 0,
&out);
} }
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SequenceConcatGradOpKernel : public framework::OpKernel<T> { class SeqConcatGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext &context) const override {
auto ins = ctx.MultiInput<framework::LoDTensor>("X"); auto xs = context.MultiInput<framework::LoDTensor>("X");
auto* out_grad = auto dxs =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out")); context.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
auto x_grads = PADDLE_ENFORCE_EQ(xs.size(), dxs.size());
ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X")); for (size_t i = 0; i < dxs.size(); ++i) {
size_t axis = static_cast<size_t>(ctx.Attr<int>("axis")); if (dxs[i] != nullptr) {
size_t level = static_cast<size_t>(ctx.Attr<int>("level")); dxs[i]->set_lod(xs[i]->lod());
const size_t n = x_grads.size(); dxs[i]->mutable_data<T>(context.GetPlace());
}
// Set Grad(X) LoD as X
for (size_t i = 0; i < n; i++) {
x_grads[i]->set_lod(ins[i]->lod());
x_grads[i]->mutable_data<T>(ctx.GetPlace());
} }
auto out_lod = ins[0]->lod(); std::vector<framework::Tensor> sliced_x;
if (axis == 0UL) { std::vector<boost::variant<boost::blank, framework::Tensor>> sliced_dx;
out_lod = ConcatLoD<LoDTensor>(ins, level);
for (size_t i = 1; i < xs[0]->lod()[0].size(); ++i) {
for (size_t j = 0; j < xs.size(); ++j) {
const framework::LoDTensor *x = xs[j];
framework::LoDTensor *dx = dxs[j];
auto &x_lod = x->lod()[0];
sliced_x.emplace_back(x->Slice(x_lod[i - 1], x_lod[i]));
if (dx != nullptr) {
sliced_dx.emplace_back(dx->Slice(x_lod[i - 1], x_lod[i]));
} else {
sliced_dx.emplace_back(boost::blank());
}
}
} }
const size_t level_idx = out_lod.size() - level - 1;
auto out_lod_level = framework::ToAbsOffset(out_lod)[level_idx];
for (size_t i = 0; i < out_lod_level.size() - 1; ++i) { math::ConcatGradFunctor<DeviceContext, T> functor;
Tensor out_grad_t = std::vector<const framework::Tensor *> sliced_x_ptr;
out_grad->Slice(static_cast<int>(out_lod_level[i]), std::vector<framework::Tensor *> sliced_dx_ptr;
static_cast<int>(out_lod_level[i + 1])); for (auto &x : sliced_x) {
auto out_grad_stride = framework::stride(out_grad_t.dims()); sliced_x_ptr.emplace_back(&x);
size_t offset = 0; }
for (size_t j = 0; j < n; ++j) { for (auto &dx : sliced_dx) {
auto x_grad_lod_level = try {
framework::ToAbsOffset(x_grads[j]->lod())[level_idx]; sliced_dx_ptr.emplace_back(&boost::get<framework::Tensor>(dx));
auto x_grad_stride = framework::stride(x_grads[j]->dims()); } catch (boost::bad_get &) {
Tensor x_grad_t = sliced_dx_ptr.emplace_back(nullptr);
x_grads[j]->Slice(static_cast<int>(x_grad_lod_level[i]),
static_cast<int>(x_grad_lod_level[i + 1]));
size_t axis_dim = x_grad_t.dims()[axis];
StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>() + offset,
out_grad_stride, out_grad_t.dims(), x_grad_stride,
x_grad_t.data<T>());
offset += axis_dim * out_grad_stride[axis];
} }
} }
functor(context.template device_context<DeviceContext>(),
detail::Ref(
context.Input<framework::Tensor>(framework::GradVarName("Out")),
"Sequence Concat OG must be set"),
sliced_x_ptr, 0, &sliced_dx_ptr);
} }
}; };
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
class TestSequenceConcat(OpTest):
def setUp(self):
x1 = np.random.random(size=(10, 80))
lod1 = [7, 3]
x2 = np.random.random(size=(20, 80))
lod2 = [12, 8]
out = np.concatenate((x1[0:lod1[0]], x2[0:lod2[0]], x1[lod1[0]:],
x2[lod2[0]:]))
out_lod = [19, 11]
self.op_type = "sequence_concat"
self.inputs = {'X': [("x1", (x1, [lod1])), ("x2", (x2, [lod2]))]}
self.outputs = {"Out": (out, [out_lod])}
def test_output(self):
self.check_output(1e-3)
def test_dx(self):
self.check_grad(inputs_to_check=['x1', 'x2'], output_names="Out")
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册