diff --git a/paddle/fluid/operators/sequence_expand_op.h b/paddle/fluid/operators/sequence_expand_op.h
index 11890b30ae598f149c08674d888dcb16d5fa2c2d..5cab367988dd26c6e0af77292bc6195c62997bfa 100644
--- a/paddle/fluid/operators/sequence_expand_op.h
+++ b/paddle/fluid/operators/sequence_expand_op.h
@@ -13,15 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <numeric>  // std::iota
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/memory/memcpy.h"
-#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/operators/math/math_function.h"
 
 namespace paddle {
 namespace operators {
 
 using LoDTensor = framework::LoDTensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
 template <typename DeviceContext, typename T>
 struct SequenceExpandFunctor {
@@ -38,23 +42,35 @@ template <typename T>
 struct SequenceExpandFunctor<platform::CPUDeviceContext, T> {
-  void operator()(const platform::CPUDeviceContext& context, const LoDTensor& x,
-                  LoDTensor* out) {
-    auto x_dims = x.dims();
-    size_t element_len = framework::product(x_dims) / x_dims[0];
-    const T* x_data = x.data<T>();
-    T* out_data = out->mutable_data<T>(context.GetPlace());
-    auto out_starts = out->lod().back();
-
-    for (size_t i = 0; i < out_starts.size() - 1; i++) {
-      int scale = out_starts[i + 1] - out_starts[i];
-      Eigen::TensorMap<
-          Eigen::Tensor<const T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
-          x_t(x_data, 1, element_len);
-      Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
-          out_t(out_data, scale, element_len);
-      Eigen::array<int, 2> cast({{scale, 1}});
-      out_t.device(*context.eigen_device()) = x_t.broadcast(cast);
-      x_data += element_len;
-      out_data += element_len * scale;
+  void operator()(const platform::CPUDeviceContext& context, const LoDTensor& x,
+                  const framework::Vector<size_t>& ref_lod, /* lod of Y at ref_level */
+                  LoDTensor* out) {
+    framework::Vector<size_t> x_lod;
+    if (x.lod().size() == 1) {
+      x_lod = x.lod()[0];
+    } else {
+      // X has no lod, so treat every row as a sequence of length 1,
+      // i.e. fake a lod of 0, 1, ..., x.dims()[0].
+      x_lod.resize(x.dims()[0] + 1);
+      std::iota(x_lod.begin(), x_lod.end(), 0);
+    }
+    int out_offset = 0;
+    auto& eigen_place = *context.eigen_device();
+    for (size_t i = 1; i < ref_lod.size(); ++i) {
+      int repeat_num = ref_lod[i] - ref_lod[i - 1];
+      int x_start = x_lod[i - 1];
+      int x_end = x_lod[i];
+      int x_seq_len = x_end - x_start;
+      if (repeat_num > 0) {
+        auto x_sub_tensor = x.Slice(x_start, x_end);
+        x_sub_tensor.Resize({1, x_sub_tensor.numel()});
+        int out_start = out_offset;
+        if (out->lod().size() == 1) {
+          out_start = out->lod()[0][out_offset];
+        }
+        auto out_sub_tensor =
+            out->Slice(out_start, out_start + x_seq_len * repeat_num);
+        out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]});
+        // Broadcast the flattened source sequence repeat_num times.
+        EigenMatrix<T>::From(out_sub_tensor).device(eigen_place) =
+            EigenMatrix<T>::From(x_sub_tensor)
+                .broadcast(Eigen::array<int, 2>({{repeat_num, 1}}));
+      }
+      out_offset += repeat_num;
     }
   }
 };
@@ -64,15 +80,42 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* x = context.Input<LoDTensor>("X");
-    auto* out = context.Output<LoDTensor>("Out");
-    auto x_dims = x->dims();
     auto* y = context.Input<LoDTensor>("Y");
-    PADDLE_ENFORCE(!y->lod().empty(), "y should have lod");
-    PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims[0]),
-                      y->lod().back().size() - 1,
-                      "The size of last lod level in Input(Y)"
-                      "must be equal to dims[0] of Input(X).");
-    out->set_lod(y->lod());
+    auto* out = context.Output<LoDTensor>("Out");
+
+    int ref_level = context.Attr<int>("ref_level");
+    auto& x_lod = x->lod();
+    auto& y_lod = y->lod();
+
+    if (ref_level == -1) ref_level = y_lod.size() - 1;
+
+    out->mutable_data<T>(context.GetPlace());
+
+    // Nothing to expand: the referenced lod level of Y holds at most one
+    // sequence, so Out is just a copy of X.
+    if (y_lod[ref_level].size() <= 1) {
+      framework::TensorCopy(*x, context.GetPlace(), out);
+      return;
+    }
+
+    // The lod level of X is at most 1. Out needs an expanded lod of its own
+    // only when X really carries a lod.
+    if (x_lod.size() == 1) {
+      framework::Vector<size_t> out_lod;
+      out_lod.push_back(0);
+      for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
+        int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
+        int x_seq_len = x_lod[0][i] - x_lod[0][i - 1];
+        // Each repetition appends one sequence of length x_seq_len.
+        for (int j = 0; j < repeat_num; ++j) {
+          out_lod.push_back(out_lod.back() + x_seq_len);
+        }
+      }
+      out->mutable_lod()->resize(1);
+      (*out->mutable_lod())[0] = out_lod;
+    }
+
     SequenceExpandFunctor<DeviceContext, T> functor;
-    functor(context.template device_context<DeviceContext>(), *x, out);
+    functor(context.template device_context<DeviceContext>(), *x,
+            y_lod[ref_level], out);
   }
 };
@@ -94,21 +137,31 @@ template <typename T>
 struct SequenceExpandGradFunctor<platform::CPUDeviceContext, T> {
-  void operator()(const platform::CPUDeviceContext& context, const LoDTensor& x,
-                  const LoDTensor& out, const LoDTensor& dout, LoDTensor* dx) {
-    auto out_last_level = out.lod().back();
-    const T* d_out_data = dout.data<T>();
-    T* d_x_data = dx->mutable_data<T>(context.GetPlace());
-    size_t element_len = dout.numel() / dout.dims()[0];
-    for (size_t i = 0; i < out_last_level.size() - 1; ++i) {
-      size_t repeat = out_last_level[i + 1] - out_last_level[i];
-      Eigen::TensorMap<
-          Eigen::Tensor<const T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
-          d_out_t(d_out_data, static_cast<int>(repeat), element_len);
-      Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>>
-          d_x_t(d_x_data, static_cast<int>(element_len));
-      d_x_t.device(*context.eigen_device()) =
-          d_out_t.sum(Eigen::array<int, 1>({{0}}));
-      d_out_data += (repeat * element_len);
-      d_x_data += element_len;
+  void operator()(const platform::CPUDeviceContext& context, const LoDTensor& x,
+                  const framework::Vector<size_t>& ref_lod, /* lod of Y at ref_level */
+                  const LoDTensor& g_out, LoDTensor* g_x) {
+    auto& x_lod = x.lod();
+
+    // Sequences whose repeat_num is 0 receive no gradient, so zero the
+    // whole gradient of X first.
+    math::SetConstant<platform::CPUDeviceContext, T> set_zero;
+    set_zero(context, g_x, static_cast<T>(0));
+
+    int g_out_offset = 0;
+    for (size_t i = 1; i < ref_lod.size(); ++i) {
+      int repeat_num = ref_lod[i] - ref_lod[i - 1];
+      if (repeat_num > 0) {
+        // When X has no lod, row i - 1 of X is a sequence of length 1.
+        int x_start = i - 1;
+        int x_end = i;
+        if (x_lod.size() == 1) {
+          x_start = x_lod[0][i - 1];
+          x_end = x_lod[0][i];
+        }
+        int x_seq_len = x_end - x_start;
+        auto g_x_sub = g_x->Slice(x_start, x_end);
+        g_x_sub.Resize(flatten_to_1d(g_x_sub.dims()));
+        int g_out_end = g_out_offset + repeat_num * x_seq_len;
+        auto g_out_sub = g_out.Slice(g_out_offset, g_out_end);
+        g_out_sub.Resize({repeat_num, g_x_sub.dims()[0]});
+        // The gradient of a broadcast is the column-wise sum over the copies.
+        math::ColwiseSum<platform::CPUDeviceContext, T> col_sum;
+        col_sum(context, g_out_sub, &g_x_sub);
+        g_out_offset += repeat_num * x_seq_len;
+      }
     }
   }
 };
@@ -117,15 +170,29 @@ template <typename DeviceContext, typename T>
 class SequenceExpandGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    auto* g_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
     auto* x = context.Input<LoDTensor>("X");
-    auto* out = context.Input<LoDTensor>("Out");
-    auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto* y = context.Input<LoDTensor>("Y");
+    auto* g_x = context.Output<LoDTensor>(framework::GradVarName("X"));
+    int ref_level = context.Attr<int>("ref_level");
+
+    g_x->mutable_data<T>(context.GetPlace());
+    g_x->set_lod(x->lod());
+
+    auto& y_lod = y->lod();
+
+    if (ref_level == -1) ref_level = y_lod.size() - 1;
+
+    // Nothing was expanded in the forward pass, so just copy the gradient.
+    if (y_lod[ref_level].size() <= 1) {
+      framework::TensorCopy(*g_out, context.GetPlace(), g_x);
+      return;
+    }
 
-    auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
-    d_x->set_lod(x->lod());
     SequenceExpandGradFunctor<DeviceContext, T> functor;
-    functor(context.template device_context<DeviceContext>(), *x, *out, *d_out,
-            d_x);
+    functor(context.template device_context<DeviceContext>(), *x,
+            y_lod[ref_level], *g_out, g_x);
   }
 };
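
For reference, below is a minimal standalone sketch of the expansion rule the rewritten kernel implements: sequence i of X, delimited by x_lod, is copied ref_lod[i] - ref_lod[i - 1] times into the output. It uses plain std::vector instead of LoDTensor and framework::Vector, and the names x_lod, ref_lod, and repeat_num simply mirror the variables above; it is an illustration of the algorithm, not part of the operator.

#include <cstddef>
#include <iostream>
#include <vector>

// Expand each sequence of `x` (rows delimited by `x_lod`) by the repeat
// counts encoded in `ref_lod`, i.e. the lod of Y at ref_level.
std::vector<float> SequenceExpand(const std::vector<float>& x,
                                  const std::vector<size_t>& x_lod,
                                  const std::vector<size_t>& ref_lod) {
  std::vector<float> out;
  for (size_t i = 1; i < ref_lod.size(); ++i) {
    size_t repeat_num = ref_lod[i] - ref_lod[i - 1];
    size_t x_start = x_lod[i - 1];  // rows [x_start, x_end) form sequence i
    size_t x_end = x_lod[i];
    for (size_t j = 0; j < repeat_num; ++j) {
      out.insert(out.end(), x.begin() + x_start, x.begin() + x_end);
    }
  }
  return out;
}

int main() {
  // X holds two sequences {1, 2} and {3, 4}, so x_lod = [0, 2, 4].
  // ref_lod = [0, 2, 4] means each sequence is repeated 2 times.
  std::vector<float> x = {1, 2, 3, 4};
  std::vector<size_t> x_lod = {0, 2, 4};
  std::vector<size_t> ref_lod = {0, 2, 4};
  for (float v : SequenceExpand(x, x_lod, ref_lod)) std::cout << v << ' ';
  std::cout << '\n';  // prints: 1 2 1 2 3 4 3 4
  return 0;
}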