// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "lite/backends/x86/math/blas.h" #include "lite/backends/x86/math/context_project.h" #include "lite/backends/x86/math/math_function.h" #include "lite/core/kernel.h" #include "lite/core/op_registry.h" namespace paddle { namespace lite { namespace kernels { namespace x86 { namespace math = paddle::lite::x86::math; template class SequenceConvCompute : public KernelLite { public: using param_t = operators::SequenceConvParam; void Run() override { auto& param = this->template Param(); auto& ctx = this->ctx_->template As(); auto* in = param.X; auto* filter = param.Filter; auto* out = param.Out; out->template mutable_data(); CHECK(in->lod().size() == 1) << "Only support one level sequence now"; int context_start = param.contextStart; int context_stride = param.contextStride; int context_length = param.contextLength; bool padding_trainable = false; const Tensor* padding_data = nullptr; int up_pad = (std::max)(0, -context_start); int down_pad = (std::max)(0, context_start + context_length - 1); auto sequence_width = static_cast(in->dims()[1]); std::vector col_shape{in->dims()[0], context_length * sequence_width}; Tensor col; col.Resize(col_shape); col.mutable_data(); // Because if padding_trainable is false, padding data should be zeros. math::SetConstant set_zero; auto blas = math::GetBlas(ctx); set_zero(ctx, &col, static_cast(0)); math::ContextProjectFunctor seq_project_functor; seq_project_functor(ctx, *in, padding_data, padding_trainable, context_start, context_length, context_stride, up_pad, down_pad, &col); blas.MatMul(col, *filter, out); } virtual ~SequenceConvCompute() = default; }; } // namespace x86 } // namespace kernels } // namespace lite } // namespace paddle