From 9bd9d8b5ca96ff442a7ba3a3df0564e414c11af5 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Thu, 18 Jan 2018 11:29:09 +0800 Subject: [PATCH] Add sequence_reshape_op. --- paddle/operators/sequence_reshape_op.cc | 78 +++++++++++++++ paddle/operators/sequence_reshape_op.h | 127 ++++++++++++++++++++++++ 2 files changed, 205 insertions(+) create mode 100644 paddle/operators/sequence_reshape_op.cc create mode 100644 paddle/operators/sequence_reshape_op.h diff --git a/paddle/operators/sequence_reshape_op.cc b/paddle/operators/sequence_reshape_op.cc new file mode 100644 index 00000000000..31a970354f2 --- /dev/null +++ b/paddle/operators/sequence_reshape_op.cc @@ -0,0 +1,78 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/sequence_reshape_op.h" + +namespace paddle { +namespace operators { + +class SequenceReshapeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceReshapeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequenceReshapeOp should not be null."); + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2."); + int dimension = ctx->Attrs().Get("dimension"); + ctx->SetOutputDim("Out", {{x_dims[0], static_cast(dimension)}}); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class SequenceReshapeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SequenceReshapeOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", ""); + AddOutput("Out", ""); + AddAttr("dimension", ""); + AddAttr("is_padding", "Default padding zero."); + AddComment(R"DOC()DOC"); + } +}; + +class SequenceReshapeGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE( + ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) of SequenceReshapeGradOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Out"), + "Input(Out) of SequenceReshapeGradOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceReshapeGradOp should not be null."); + + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ framework::GradVarName("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(sequence_reshape, ops::SequenceReshapeOp, + ops::SequenceReshapeOpMaker); +REGISTER_OPERATOR(sequence_reshape_grad, ops::SequenceReshapeGradOp); +REGISTER_OP_CPU_KERNEL( + sequence_reshape, + ops::SequenceReshapeKernel); +REGISTER_OP_CPU_KERNEL( + sequence_reshape_grad, + ops::SequenceReshapeGradKernel); diff --git a/paddle/operators/sequence_reshape_op.h b/paddle/operators/sequence_reshape_op.h new file mode 100644 index 00000000000..bc7694b6b1b --- /dev/null +++ b/paddle/operators/sequence_reshape_op.h @@ -0,0 +1,127 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +template +class SequenceReshapeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + int out_width = context.Attr("dimension"); + bool whether_padding = context.Attr("whether_padding"); + + const T* p_in_data = in->data(); + T* p_out_data = out->mutable_data(context.GetPlace()); + + // compute shape for output + auto in_dims = in->dims(); + int64_t in_width = in_dims[1]; + auto& in_lod = in->lod(); + + PADDLE_ENFORCE_EQ(in_lod.size(), 1UL, + "Only support one level sequence now."); + PADDLE_ENFORCE_GE( + in_dims[0], + /* batch size = */ static_cast(in_lod[0].size() - 1), + "The 1st dimension of Input(X) must be equal or larger than batch " + "size."); + + auto in_lod_l0 = in_lod[0]; + int seq_num = in_lod_l0.size() - 1; + + auto& out_lod = *out->mutable_lod(); + out_lod.push_back(std::vector({0})); + size_t offset = 0; + for (int i = 0; i < seq_num; ++i) { + size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i]; + if (whether_padding) { + offset += std::ceil((float)(seq_len * in_width) / out_width); + } else { + offset += (seq_len * in_width) / out_width; + } + out_lod[0].push_back(offset); + } + + out->Resize({{static_cast(out_lod[0].back()), out_width}}); + math::set_constant(context.device_context(), out, 0.0f); + + for (int i = 0; i < seq_num; ++i) { + size_t in_offset = in_lod_l0[i] * in_width; + size_t out_offset = out_lod[0][i] * out_width; + size_t bytes = sizeof(T) * (in_lod_l0[i + 1] - in_lod_l0[i]) * in_width; + if (platform::is_cpu_place(context.GetPlace())) { + std::memcpy(p_out_data + out_offset, p_in_data + in_offset, bytes); + } else { +#ifdef PADDLE_WITH_CUDA + auto& dev_ctx = context.template device_context(); + memory::Copy(boost::get(context.GetPlace()), + p_out_data + out_offset, + boost::get(context.GetPlace()), + p_in_data + in_offset, bytes, dev_ctx.stream()); +#endif + } + } + } +}; + +template +class SequenceReshapeGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x_tensor_ptr = context.Input("X"); + auto* out_tensor_ptr = context.Input("Out"); + auto* out_grad_tensor_ptr = + context.Input(framework::GradVarName("Out")); + auto* x_grad_tensor_ptr = + context.Output(framework::GradVarName("X")); + + T* p_x_grad_data = x_grad_tensor_ptr->mutable_data(context.GetPlace()); + const T* p_out_grad_data = out_grad_tensor_ptr->data(); + + auto& x_lod = x_tensor_ptr->lod(); + int seq_num = x_lod[0].size() - 1; + int x_width = x_tensor_ptr->dims()[1]; + auto& out_lod = out_tensor_ptr->lod(); + int out_width = out_tensor_ptr->dims()[1]; + + for (int i = 0; i < seq_num; ++i) { + size_t src_offset = out_lod[0][i] * out_width; + size_t dst_offset = x_lod[0][i] * x_width; + size_t bytes = sizeof(T) * (x_lod[0][i + 1] - x_lod[0][i]) * x_width; + if (platform::is_cpu_place(context.GetPlace())) { + std::memcpy(p_x_grad_data + dst_offset, p_out_grad_data + src_offset, + bytes); + } else { +#ifdef PADDLE_WITH_CUDA + auto& dev_ctx = context.template device_context(); + memory::Copy(boost::get(context.GetPlace()), + p_x_grad_data + dst_offset, + boost::get(context.GetPlace()), + p_out_grad_data + src_offset, bytes, dev_ctx.stream()); +#endif + } + } + } +}; + +} // namespace operators +} // namespace paddle -- GitLab