提交 e7940141 编写于 作者: C chengduoZH

refine seq_concat

上级 437debf4
......@@ -441,7 +441,10 @@ static void InitInferShapeFuncs() {
for (auto &kern_pair : OperatorWithKernel::AllOpKernels()) {
auto op_type = kern_pair.first;
auto &op_info = info_map.at(op_type);
auto it = info_map.find(op_type);
PADDLE_ENFORCE(it != info_map.end(), "%s has not been registered",
op_type);
auto &op_info = it->second;
auto op = static_cast<OperatorWithKernel *>(op_info.Creator()(
"", VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
if (op_info.infer_shape_) { // infer_shape has been registered.
......
......@@ -95,6 +95,7 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X"));
}
};
......
......@@ -109,8 +109,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
concat_grad_functor;
concat_grad_functor(dev_ctx, *out_grad, ins, static_cast<int>(axis),
&outputs);
concat_grad_functor(dev_ctx, *out_grad,
ctx.MultiInput<framework::Tensor>("X"),
static_cast<int>(axis), &outputs);
}
}
};
......
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
......@@ -24,10 +24,22 @@ namespace detail {
* and passed by `args`
*/
template <typename T, typename... ARGS>
inline T &Ref(T *ptr, ARGS &&... args) {
inline T& Ref(T* ptr, ARGS&&... args) {
PADDLE_ENFORCE(ptr != nullptr, args...);
return *ptr;
}
template <typename T, typename... ARGS>
inline std::vector<std::reference_wrapper<T>> VectorRef(
const std::vector<T*>& vec, ARGS&&... args) {
std::vector<std::reference_wrapper<T>> result;
result.reserve(vec.size());
for (auto* ptr : vec) {
result.emplace_back(Ref(ptr, args...));
}
return result;
}
} // namespace detail
} // namespace operators
} // namespace paddle
......@@ -27,7 +27,7 @@ template <typename T>
class ConcatFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis,
const std::vector<framework::Tensor>& input, int axis,
framework::Tensor* output) {
// TODO(zcd): Add input data validity checking
int num = input.size();
......@@ -71,7 +71,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input,
const std::vector<const framework::LoDTensor*>& ref_inputs,
const std::vector<const framework::Tensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs) {
// TODO(zcd): Add input data validity checking
size_t num = outputs->size();
......@@ -109,16 +109,11 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
}
}
};
#define DEFINE_FUNCTOR(type) \
template class ConcatFunctor<platform::CPUDeviceContext, type>; \
template class ConcatGradFunctor<platform::CPUDeviceContext, type>;
template class ConcatFunctor<platform::CPUDeviceContext, int>;
template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
template class ConcatFunctor<platform::CPUDeviceContext, float>;
template class ConcatFunctor<platform::CPUDeviceContext, double>;
template class ConcatGradFunctor<platform::CPUDeviceContext, int>;
template class ConcatGradFunctor<platform::CPUDeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CPUDeviceContext, float>;
template class ConcatGradFunctor<platform::CPUDeviceContext, double>;
FOR_ALL_TYPES(DEFINE_FUNCTOR);
} // namespace math
} // namespace operators
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
......@@ -118,7 +119,7 @@ template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis,
const std::vector<framework::Tensor>& input, int axis,
framework::Tensor* output) {
// TODO(zcd): Add input data validity checking
int in_num = input.size();
......@@ -192,8 +193,8 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input,
const std::vector<const framework::LoDTensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs) {
const std::vector<const framework::Tensor*>& ref_inputs,
int axis, std::vector<framework::Tensor*>* outputs) {
// TODO(zcd): Add input data validity checking
int o_num = outputs->size();
int out_row = 1;
......@@ -261,15 +262,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
}
};
template class ConcatFunctor<platform::CUDADeviceContext, int>;
template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatFunctor<platform::CUDADeviceContext, float>;
template class ConcatFunctor<platform::CUDADeviceContext, double>;
#define DEFINE_FUNCTOR(type) \
template class ConcatFunctor<platform::CUDADeviceContext, type>; \
template class ConcatGradFunctor<platform::CUDADeviceContext, type>
template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
template class ConcatGradFunctor<platform::CUDADeviceContext, double>;
FOR_ALL_TYPES(DEFINE_FUNCTOR);
} // namespace math
} // namespace operators
......
......@@ -37,7 +37,7 @@ template <typename DeviceContext, typename T>
class ConcatFunctor {
public:
void operator()(const DeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis,
const std::vector<framework::Tensor>& input, int axis,
framework::Tensor* output);
};
......@@ -57,10 +57,21 @@ template <typename DeviceContext, typename T>
class ConcatGradFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
const std::vector<const framework::LoDTensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs);
const std::vector<const framework::Tensor*>& ref_inputs,
int axis, std::vector<framework::Tensor*>* outputs);
};
} // namespace math
} // namespace operators
} // namespace paddle
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(::paddle::platform::float16)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/sequence_concat_op.h"
#include <vector>
namespace paddle {
namespace operators {
class SequenceConcatOp : public framework::OperatorWithKernel {
class SeqConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("X"),
"Inputs(X) of SequenceConcatOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceConcatOp should not be null.");
const size_t level = static_cast<size_t>(ctx->Attrs().Get<int>("level"));
const size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
PADDLE_ENFORCE(level == 0UL || level == 1UL,
"The sequence_concat operator only accepts sequence "
"or a nested sequence as its input.");
auto ins_dims = ctx->GetInputsDim("X");
framework::DDim out_dims = ins_dims[0];
const size_t n = ins_dims.size();
for (size_t i = 1; i < n; ++i) {
out_dims[axis] += ins_dims[i][axis];
}
ctx->SetOutputDim("Out", out_dims);
void Make() override {
AddInput("X", "The inputs of sequence concat op").AsDuplicable();
AddOutput("Out", "The output of sequence concat op");
AddComment(
"Sequence Concat Op\n"
"It will concat LoD tensors by its sequence information.\n"
"For example:\n"
" LoD of X1 = [0, 3, 7]\n"
" LoD of X2 = [0, 7, 9]\n"
" Result LoD is [0, (3+7), (7+9)]\n"
" i.e.[0, 10, 16]\n");
}
};
class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
class SeqConcatShapeInferer : public framework::InferShapeBase {
public:
void Make() override {
AddInput("X",
"(LodTensorArray) Input is a vector of LoDTensor, "
"each of which is a variable-length sequence or nested sequence.")
.AsDuplicable();
AddOutput("Out",
"(LoDTensor), Variable-length output of "
"sequence_concat Op.");
AddAttr<int>("axis",
"(int, default 0) "
"The axis along which the inputs will be joined. "
"If axis is 0, the inputs will be joined with LoD index.")
.SetDefault(0);
AddAttr<int>("level",
"(int, default 0) "
"The level at which the inputs will be joined. "
"If the level is 0, the inputs will be joined at the nested "
"sequence level. "
"If the level is 1, the inputs will be joined at the "
"sequence level. "
"The level should be less than the level number of inputs.")
.SetDefault(0);
AddComment(R"DOC(
The sequence_concat operator concatenates multiple LoDTensors.
It only supports sequence (LoD Tensor with level number is 1)
or a nested sequence (LoD tensor with level number is 2) as its input.
- Case1:
If the axis is other than 0(here, axis is 1 and level is 1),
each input should have the same LoD information and the LoD
information of the output keeps the same as the input.
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
LoD(x1) = {{0,2,4}, {0,1,2,3,4}}; Dims(x1) = (4,4,4)
LoD(Out) = {{0,2,4}, {0,1,2,3,4}}; Dims(Out) = (4,7,4)
- Case2:
If the axis is 0(here, leve is 0), the inputs are concatenated along
time steps, the LoD information of the output need to re-compute.
The LoD information of level-1 should be same.
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
LoD(x1) = {{0,2,4}, {0,1,3,5,7}}; Dims(x1) = (7,3,4)
LoD(Out) = {{0,2,4}, {0,2,5,8,11}}; Dims(Out) = (11,3,4)
- Case3:
If the axis is 0(here, level is 1).
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
LoD(x1) = {{0,3,4}, {0,1,3,5,7}}; Dims(x1) = (7,3,4)
LoD(Out) = {{0,5,8}, {0,1,2,3,5,7,8,9,11}}; Dims(Out) = (11,3,4)
- Case4:
If the LoD number is 1, axis is 0, level is 0
LoD(x0) = {{0,1,2,3,4}}; Dims(x0) = (4,3,4)
LoD(x1) = {{0,1,3,5,7}}; Dims(x1) = (7,3,4)
LoD(Out) = {{0,2,5,8,11}}; Dims(Out) = (11,3,4)
NOTE: The levels of all the inputs should be the same.
)DOC");
void operator()(framework::InferShapeContext *context) const override {
try {
PADDLE_ENFORCE(context->HasInputs("X"));
PADDLE_ENFORCE(context->HasOutput("Out"));
auto x_dims = context->GetInputsDim("X");
int64_t batch_size = 0;
int64_t feature_size = 0;
std::vector<int64_t> out_dims;
for (auto &x_dim : x_dims) {
if (out_dims.empty()) {
out_dims = framework::vectorize(x_dim);
}
batch_size += x_dim[0];
if (feature_size == 0) {
feature_size = framework::product(x_dim) / x_dim[0];
} else {
PADDLE_ENFORCE_EQ(
feature_size, framework::product(x_dim) / x_dim[0],
"Inputs of sequence concat must have same feature size");
}
}
if (batch_size < 0) {
batch_size = -1; // Normalize batch size for compile time.
}
out_dims[0] = batch_size;
context->SetOutputDim("Out", framework::make_ddim(out_dims));
if (!context->IsRuntime()) { // Runtime LoD infershape will be computed
// in Kernel.
context->ShareLoD("X", "Out");
}
} catch (...) {
PADDLE_THROW("Unknown error");
}
}
};
class SequenceConcatGradOp : public framework::OperatorWithKernel {
class SeqConcatGradShapeInferer : public framework::InferShapeBase {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"The gradient of Out should not be null.");
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
"The gradient of X should not be null.");
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
void operator()(framework::InferShapeContext *context) const override {
context->SetOutputsDim(framework::GradVarName("X"),
context->GetInputsDim("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_concat, ops::SequenceConcatOp,
ops::SequenceConcatOpMaker,
paddle::framework::DefaultGradOpDescMaker<
false> /* set false to disable empty grad */);
REGISTER_OPERATOR(sequence_concat_grad, ops::SequenceConcatGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_concat,
ops::SequenceConcatOpKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
sequence_concat_grad,
ops::SequenceConcatGradOpKernel<paddle::platform::CPUDeviceContext, float>);
namespace op = paddle::operators;
REGISTER_OPERATOR(sequence_concat, paddle::framework::OperatorWithKernel,
op::SeqConcatOpMaker, op::SeqConcatShapeInferer,
paddle::framework::DefaultGradOpDescMaker<false>);
template <typename T>
using Kernel = op::SeqConcatKernel<paddle::platform::CPUDeviceContext, T>;
REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>);
REGISTER_OPERATOR(sequence_concat_grad, paddle::framework::OperatorWithKernel,
op::SeqConcatGradShapeInferer);
template <typename T>
using GradKernel =
op::SeqConcatGradKernel<paddle::platform::CPUDeviceContext, T>;
REGISTER_OP_CPU_KERNEL(sequence_concat_grad, GradKernel<float>,
GradKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/sequence_concat_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sequence_concat,
ops::SequenceConcatOpKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(sequence_concat_grad,
ops::SequenceConcatGradOpKernel<
paddle::platform::CUDADeviceContext, float>);
template <typename T>
using Kernel =
paddle::operators::SeqConcatKernel<paddle::platform::CUDADeviceContext, T>;
REGISTER_OP_CUDA_KERNEL(sequence_concat, Kernel<float>, Kernel<double>);
template <typename T>
using GradKernel =
paddle::operators::SeqConcatGradKernel<paddle::platform::CUDADeviceContext,
T>;
REGISTER_OP_CUDA_KERNEL(sequence_concat_grad, GradKernel<float>,
GradKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/concat.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
template <typename T>
LoD ConcatLoD(const std::vector<const T*> ins, const size_t level) {
auto out_lod = ins[0]->lod();
auto numLevels = ins[0]->NumLevels();
const size_t n = ins.size();
const size_t level_idx = ins[0]->NumLevels() - 1 - level;
for (size_t i = 1; i < n; ++i) {
for (size_t j = 0; j < ins[i]->lod()[level_idx].size(); ++j) {
out_lod[level_idx][j] += ins[i]->lod()[level_idx][j];
}
}
for (size_t i = level_idx; i < numLevels - 1; ++i) {
size_t lod_len = 1;
for (size_t j = 0; j < n; ++j) {
lod_len += ins[j]->lod()[i + 1].size() - 1;
}
out_lod[i + 1].clear();
out_lod[i + 1].resize(lod_len);
size_t idx = 1;
for (size_t j = 0; j < ins[0]->lod()[i].size() - 1; ++j) {
for (size_t k = 0; k < n; ++k) {
for (size_t m = ins[k]->lod()[i][j]; m < ins[k]->lod()[i][j + 1]; ++m) {
out_lod[i + 1][idx] = out_lod[i + 1][idx - 1] +
ins[k]->lod()[i + 1][m + 1] -
ins[k]->lod()[i + 1][m];
idx++;
}
}
}
}
return out_lod;
namespace detail {
template <typename Container>
inline framework::LoD ConcatLoD(const Container &xs,
std::vector<framework::Tensor> *xs_in_order) {
std::vector<size_t> result;
result.resize(xs[0].get().lod()[0].size());
for (size_t i = 1; i < result.size(); ++i) {
size_t sum = 0;
for (size_t j = 0; j < xs.size(); ++j) {
auto &x_lod = xs[j].get().lod()[0];
const framework::Tensor &tensor = xs[j].get();
xs_in_order->emplace_back(tensor.Slice(x_lod[i - 1], x_lod[i]));
sum += x_lod[i];
}
result[i] = sum;
}
framework::LoD lod;
lod.emplace_back(result);
return lod;
}
} // namespace detail
template <typename DeviceContext, typename T>
class SequenceConcatOpKernel : public framework::OpKernel<T> {
class SeqConcatKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
const size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
const size_t level = static_cast<size_t>(ctx.Attr<int>("level"));
const size_t n = ins.size();
for (size_t i = 1; i < n; ++i) {
PADDLE_ENFORCE_EQ(ins[0]->NumLevels(), ins[i]->NumLevels(),
"The levels of all the input LoDTensors "
"should be the same.");
PADDLE_ENFORCE_EQ(ins[0]->dims().size(), ins[i]->dims().size(),
"The dimension size of all the input LoDTensors "
"should be the same.");
const size_t dims_size = ins[i]->dims().size();
for (size_t j = 0; j < dims_size; ++j) {
if (j == axis) continue;
PADDLE_ENFORCE_EQ(ins[0]->dims()[j], ins[i]->dims()[j],
"Except for the dimension of the specified "
"axis along which all the inputs are concatenated, "
"dimensions of all the other axises of the input "
"LoDTensors should be the same.");
}
}
PADDLE_ENFORCE_GT(ins[0]->NumLevels(), level,
"The levels of all the input LoDTensors "
"should be greater than the specify level");
out->mutable_data<T>(ctx.GetPlace());
auto out_lod = ins[0]->lod();
if (axis == 0) {
out_lod = ConcatLoD<LoDTensor>(ins, level);
}
out->set_lod(out_lod);
const size_t level_idx = out_lod.size() - level - 1;
auto out_lod_level = framework::ToAbsOffset(out_lod)[level_idx];
for (size_t i = 0; i < out_lod_level.size() - 1; ++i) {
Tensor out_t = out->Slice(static_cast<int>(out_lod_level[i]),
static_cast<int>(out_lod_level[i + 1]));
auto out_stride = framework::stride(out_t.dims());
size_t offset = 0;
for (size_t j = 0; j < n; ++j) {
auto in_lod_level = framework::ToAbsOffset(ins[j]->lod())[level_idx];
auto in_stride = framework::stride(ins[j]->dims());
Tensor in_t = ins[j]->Slice(static_cast<int>(in_lod_level[i]),
static_cast<int>(in_lod_level[i + 1]));
size_t axis_dim = in_t.dims()[axis];
StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(), in_stride,
in_t.dims(), out_stride, out_t.data<T>() + offset);
offset += axis_dim * in_stride[axis];
}
}
void Compute(const framework::ExecutionContext &context) const override {
auto xs = detail::VectorRef(context.MultiInput<framework::LoDTensor>("X"),
"Cannot find multiple input X");
auto &out = detail::Ref(context.Output<framework::LoDTensor>("Out"),
"Cannot find output");
size_t lod_size = 0;
for (auto &x : xs) {
if (lod_size == 0) {
lod_size = x.get().lod()[0].size();
} else {
PADDLE_ENFORCE_EQ(
lod_size, x.get().lod()[0].size(),
"The number of sequence must be same between each input");
}
}
PADDLE_ENFORCE_NE(lod_size, 0, "Each input must have sequence information");
std::vector<framework::Tensor> x_in_order;
out.set_lod(detail::ConcatLoD(xs, &x_in_order));
out.mutable_data<T>(context.GetPlace());
math::ConcatFunctor<DeviceContext, T> functor;
functor(context.template device_context<DeviceContext>(), x_in_order, 0,
&out);
}
};
template <typename DeviceContext, typename T>
class SequenceConcatGradOpKernel : public framework::OpKernel<T> {
class SeqConcatGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::LoDTensor>("X");
auto* out_grad =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto x_grads =
ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
size_t level = static_cast<size_t>(ctx.Attr<int>("level"));
const size_t n = x_grads.size();
void Compute(const framework::ExecutionContext &context) const override {
auto xs = context.MultiInput<framework::LoDTensor>("X");
auto dxs =
context.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
PADDLE_ENFORCE_EQ(xs.size(), dxs.size());
for (size_t i = 0; i < dxs.size(); ++i) {
if (dxs[i] != nullptr) {
dxs[i]->set_lod(xs[i]->lod());
dxs[i]->mutable_data<T>(context.GetPlace());
}
}
std::vector<framework::Tensor> sliced_x;
std::vector<boost::variant<boost::blank, framework::Tensor>> sliced_dx;
// Set Grad(X) LoD as X
for (size_t i = 0; i < n; i++) {
x_grads[i]->set_lod(ins[i]->lod());
x_grads[i]->mutable_data<T>(ctx.GetPlace());
for (size_t i = 1; i < xs[0]->lod()[0].size(); ++i) {
for (size_t j = 0; j < xs.size(); ++j) {
const framework::LoDTensor *x = xs[j];
framework::LoDTensor *dx = dxs[j];
auto &x_lod = x->lod()[0];
sliced_x.emplace_back(x->Slice(x_lod[i - 1], x_lod[i]));
if (dx != nullptr) {
sliced_dx.emplace_back(dx->Slice(x_lod[i - 1], x_lod[i]));
} else {
sliced_dx.emplace_back(boost::blank());
}
}
auto out_lod = ins[0]->lod();
if (axis == 0UL) {
out_lod = ConcatLoD<LoDTensor>(ins, level);
}
const size_t level_idx = out_lod.size() - level - 1;
auto out_lod_level = framework::ToAbsOffset(out_lod)[level_idx];
for (size_t i = 0; i < out_lod_level.size() - 1; ++i) {
Tensor out_grad_t =
out_grad->Slice(static_cast<int>(out_lod_level[i]),
static_cast<int>(out_lod_level[i + 1]));
auto out_grad_stride = framework::stride(out_grad_t.dims());
size_t offset = 0;
math::ConcatGradFunctor<DeviceContext, T> functor;
std::vector<const framework::Tensor *> sliced_x_ptr;
std::vector<framework::Tensor *> sliced_dx_ptr;
for (auto &x : sliced_x) {
sliced_x_ptr.emplace_back(&x);
}
for (size_t j = 0; j < n; ++j) {
auto x_grad_lod_level =
framework::ToAbsOffset(x_grads[j]->lod())[level_idx];
auto x_grad_stride = framework::stride(x_grads[j]->dims());
Tensor x_grad_t =
x_grads[j]->Slice(static_cast<int>(x_grad_lod_level[i]),
static_cast<int>(x_grad_lod_level[i + 1]));
size_t axis_dim = x_grad_t.dims()[axis];
StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>() + offset,
out_grad_stride, out_grad_t.dims(), x_grad_stride,
x_grad_t.data<T>());
offset += axis_dim * out_grad_stride[axis];
for (auto &dx : sliced_dx) {
try {
sliced_dx_ptr.emplace_back(&boost::get<framework::Tensor>(dx));
} catch (boost::bad_get &) {
sliced_dx_ptr.emplace_back(nullptr);
}
}
functor(context.template device_context<DeviceContext>(),
detail::Ref(
context.Input<framework::Tensor>(framework::GradVarName("Out")),
"Sequence Concat OG must be set"),
sliced_x_ptr, 0, &sliced_dx_ptr);
}
};
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
class TestSequenceConcat(OpTest):
def setUp(self):
x1 = np.random.random(size=(10, 80))
lod1 = [7, 3]
x2 = np.random.random(size=(20, 80))
lod2 = [12, 8]
out = np.concatenate((x1[0:lod1[0]], x2[0:lod2[0]], x1[lod1[0]:],
x2[lod2[0]:]))
out_lod = [19, 11]
self.op_type = "sequence_concat"
self.inputs = {'X': [("x1", (x1, [lod1])), ("x2", (x2, [lod2]))]}
self.outputs = {"Out": (out, [out_lod])}
def test_output(self):
self.check_output(1e-3)
def test_dx(self):
self.check_grad(inputs_to_check=['x1', 'x2'], output_names="Out")
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册