From eb42f9ee956a4ceaee91a7a4e2564d618e1594d0 Mon Sep 17 00:00:00 2001
From: lhl960107 <649069257@qq.com>
Date: Mon, 16 Sep 2019 15:35:00 +0800
Subject: [PATCH] Gru op (#2002)

* add x86 gru && relu && sequence_expand_as ops

test=develop
---
 .../x86/math/detail/activation_functions.h    |   3 +-
 lite/backends/x86/math/sequence2batch.h       |   2 +-
 lite/kernels/x86/CMakeLists.txt               |   9 +-
 lite/kernels/x86/activation_compute.cc        | 100 +-------
 lite/kernels/x86/activation_compute.h         | 120 ++++++++++
 lite/kernels/x86/activation_compute_test.cc   |  83 +++++++
 lite/kernels/x86/gru_compute.cc               |  36 +++
 lite/kernels/x86/gru_compute.h                | 221 ++++++++++++++++++
 lite/kernels/x86/gru_compute_test.cc          | 155 ++++++++++++
 lite/kernels/x86/relu_compute_test.cc         |   4 +-
 .../kernels/x86/sequence_expand_as_compute.cc |  26 +++
 lite/kernels/x86/sequence_expand_as_compute.h |  81 +++++++
 .../x86/sequence_expand_as_compute_test.cc    |  96 ++++++++
 lite/operators/CMakeLists.txt                 |   1 +
 lite/operators/op_params.h                    |   6 +
 lite/operators/sequence_expand_as_op.cc       |  76 ++++++
 lite/operators/sequence_expand_as_op.h        |  47 ++++
 17 files changed, 966 insertions(+), 100 deletions(-)
 create mode 100644 lite/kernels/x86/activation_compute.h
 create mode 100644 lite/kernels/x86/activation_compute_test.cc
 create mode 100644 lite/kernels/x86/gru_compute.cc
 create mode 100644 lite/kernels/x86/gru_compute.h
 create mode 100644 lite/kernels/x86/gru_compute_test.cc
 create mode 100644 lite/kernels/x86/sequence_expand_as_compute.cc
 create mode 100644 lite/kernels/x86/sequence_expand_as_compute.h
 create mode 100644 lite/kernels/x86/sequence_expand_as_compute_test.cc
 create mode 100644 lite/operators/sequence_expand_as_op.cc
 create mode 100644 lite/operators/sequence_expand_as_op.h

diff --git a/lite/backends/x86/math/detail/activation_functions.h b/lite/backends/x86/math/detail/activation_functions.h
index cb215df722..d12b1594d0 100644
--- a/lite/backends/x86/math/detail/activation_functions.h
+++ b/lite/backends/x86/math/detail/activation_functions.h
@@ -45,7 +45,10 @@ inline ActivationType GetActivationType(const std::string &type) {
   } else if (type == "identity" || type == "") {
     return ActivationType::kIdentity;
   }
-  PADDLE_ENFORCE(false, "Not support type %s", type);
+  LOG(ERROR) << "Unsupported activation type: " << type;
+  // PADDLE_ENFORCE(false, "Not support type %s", type);
+  // PADDLE_THROW("Not support type %s.", type);
+  // Fall back to identity so this non-void function always returns a value.
+  return ActivationType::kIdentity;
 }
 
diff --git a/lite/backends/x86/math/sequence2batch.h b/lite/backends/x86/math/sequence2batch.h
index 807558e9d8..a97bfaf666 100644
--- a/lite/backends/x86/math/sequence2batch.h
+++ b/lite/backends/x86/math/sequence2batch.h
@@ -19,7 +19,7 @@ limitations under the License.
 */
 #include "lite/core/context.h"
 #include "lite/core/tensor.h"
 #include "lite/fluid/eigen.h"
-#include "lite/fluid/lod.h"
+// #include "lite/fluid/lod.h"
 #include "lite/utils/paddle_enforce.h"
 
 namespace paddle {
diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt
index 4b50ec1f0d..c8da4fc669 100644
--- a/lite/kernels/x86/CMakeLists.txt
+++ b/lite/kernels/x86/CMakeLists.txt
@@ -1,4 +1,4 @@
-# lite_cc_library(activation_compute_x86 SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_op)
+add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_ops)
 # lite_cc_library(mean_compute_x86 SRCS mean_compute.cc DEPS ${lite_kernel_deps})
 # lite_cc_library(fill_constant_compute_x86 SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps})
 # lite_cc_library(sgd_compute_x86 SRCS sgd_compute.cc DEPS ${lite_kernel_deps})
@@ -18,13 +18,15 @@ add_kernel(reshape_compute_x86 X86 basic SRCS reshape_compute.cc DEPS ${lite_ker
 # lite_cc_library(pool_compute_x86 SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling)
 # lite_cc_library(batch_norm_compute_x86 SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
 # lite_cc_library(uniform_random_compute_x86 SRCS uniform_random_compute.cc DEPS ${lite_kernel_deps} )
+add_kernel(gru_compute_x86 X86 basic SRCS gru_compute.cc DEPS ${lite_kernel_deps} blas math_function sequence2batch gru_compute)
+#add_kernel(gru_compute_x86 X86 basic SRCS gru_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(sequence_expand_as_compute_x86 X86 basic SRCS sequence_expand_as_compute.cc DEPS ${lite_kernel_deps})
 
 # lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86)
 # lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
 # lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86)
 # lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86)
 # lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86)
-# lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS relu_compute_x86)
 # lite_cc_test(test_scale_compute_x86 SRCS scale_compute_test.cc DEPS scale_compute_x86)
 # lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
 # lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86)
@@ -49,4 +51,7 @@ lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc D
 lite_cc_test(test_shape_compute_x86 SRCS shape_compute_test.cc DEPS shape_compute_x86)
 lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86)
 lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86)
+lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS activation_compute_x86)
+lite_cc_test(test_sequence_expand_as_compute_x86 SRCS sequence_expand_as_compute_test.cc DEPS sequence_expand_as_compute_x86)
+lite_cc_test(test_gru_compute_x86 SRCS gru_compute_test.cc DEPS gru_compute_x86)
 lite_cc_test(test_matmul_compute_x86 SRCS matmul_compute_test.cc DEPS matmul_compute_x86)
diff --git a/lite/kernels/x86/activation_compute.cc b/lite/kernels/x86/activation_compute.cc
index 94d877de28..0ed09c43a5 100644
--- a/lite/kernels/x86/activation_compute.cc
+++ b/lite/kernels/x86/activation_compute.cc
@@ -12,94 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "lite/core/kernel.h"
-#include "lite/core/op_registry.h"
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/activation_op.h"
-
-namespace paddle {
-namespace lite {
-namespace kernels {
-namespace x86 {
-
-template <typename Functor>
-void Activate(const platform::CPUDeviceContext& context,
-              const framework::LoDTensor* X,
-              framework::LoDTensor* Out) {
-  using T = typename Functor::ELEMENT_TYPE;
-  auto* place = context.eigen_device();
-  auto x =
-      framework::EigenVector<T>::Flatten(paddle::operators::detail::Ref(X));
-  auto out =
-      framework::EigenVector<T>::Flatten(paddle::operators::detail::Ref(Out));
-  Functor()(*place, x, out);
-}
-
-template <typename Functor>
-void ActivateGrad(const platform::CPUDeviceContext& context,
-                  const framework::LoDTensor* X,
-                  const framework::LoDTensor* Out,
-                  const framework::LoDTensor* Out_grad,
-                  framework::LoDTensor* X_grad) {
-  using T = typename Functor::ELEMENT_TYPE;
-  auto* place = context.eigen_device();
-  auto x =
-      framework::EigenVector<T>::Flatten(paddle::operators::detail::Ref(X));
-  auto out =
-      framework::EigenVector<T>::Flatten(paddle::operators::detail::Ref(Out));
-  auto x_grad = framework::EigenVector<T>::Flatten(
-      paddle::operators::detail::Ref(X_grad));
-  auto out_grad = framework::EigenVector<T>::Flatten(
-      paddle::operators::detail::Ref(Out_grad));
-  Functor()(*place, x, out, out_grad, x_grad);
-}
-
-template <typename T>
-class SquareCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
- public:
-  using param_t = operators::ActivationParam;
-
-  void Run() override {
-    auto& context = ctx_->As<X86Context>();
-    auto& param = *param_.get_mutable<param_t>();
-    CHECK(context.x86_device_context());
-
-    param.Out->template mutable_data<T>();
-    Activate<paddle::operators::SquareFunctor<T>>(*context.x86_device_context(),
-                                                  &param.X->raw_tensor(),
-                                                  &param.Out->raw_tensor());
-  }
-
-  virtual ~SquareCompute() = default;
-};
-
-template <typename T>
-class SquareGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
- public:
-  using param_t = operators::ActivationGradParam;
-
-  void Run() override {
-    auto& context = ctx_->As<X86Context>();
-    auto& param = *param_.get_mutable<param_t>();
-    CHECK(context.x86_device_context());
-    param.X_grad->template mutable_data<T>();
-
-    ActivateGrad<paddle::operators::SquareGradFunctor<T>>(
-        *context.x86_device_context(),
-        &param.X->raw_tensor(),
-        &param.Out->raw_tensor(),
-        &param.Out_grad->raw_tensor(),
-        &param.X_grad->raw_tensor());
-  }
-
-  virtual ~SquareGradCompute() = default;
-};
-
-}  // namespace x86
-}  // namespace kernels
-}  // namespace lite
-}  // namespace paddle
+#include "lite/kernels/x86/activation_compute.h"
 
 // float
 REGISTER_LITE_KERNEL(square,
@@ -112,16 +25,13 @@ REGISTER_LITE_KERNEL(square,
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
     .Finalize();
 
-REGISTER_LITE_KERNEL(square_grad,
+// float
+REGISTER_LITE_KERNEL(relu,
                      kX86,
                      kFloat,
                      kNCHW,
-                     paddle::lite::kernels::x86::SquareGradCompute<float>,
+                     paddle::lite::kernels::x86::ReluCompute<float>,
                      def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
-    .BindInput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
-    .BindInput(paddle::framework::GradVarName("Out"),
-               {LiteType::GetTensorTy(TARGET(kX86))})
-    .BindOutput(paddle::framework::GradVarName("X"),
-               {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
    .Finalize();
diff --git a/lite/kernels/x86/activation_compute.h b/lite/kernels/x86/activation_compute.h
new file mode 100644
index 0000000000..105bc70e7a
--- /dev/null
+++ b/lite/kernels/x86/activation_compute.h
@@ -0,0 +1,120 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
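+//
+// x86 activations are implemented as element-wise Eigen functors: each
+// functor exposes operator()(device, x, out) and a thin KernelLite wrapper
+// flattens the tensor and applies it. A hypothetical tanh kernel (sketch
+// only, not part of this patch) would only need:
+//
+//   template <typename T>
+//   struct TanhFunctor : public BaseActivationFunctor<T> {
+//     template <typename Device, typename X, typename Out>
+//     void operator()(Device d, X x, Out out) const {
+//       out.device(d) = x.tanh();
+//     }
+//   };
+//
+// plus a ReluCompute-style wrapper and a REGISTER_LITE_KERNEL entry in
+// activation_compute.cc.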
+#pragma once
+
+#include <utility>
+#include <vector>
+#include "lite/core/kernel.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+#include "lite/fluid/eigen.h"
+#include "lite/operators/activation_ops.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+enum ActBwdOpFwdDeps {
+  kNoDeps = 0x00,  // Do not need any forward input/output
+  kDepX = 0x01,    // Only need forward input X
+  kDepOut = 0x02,  // Only need forward output Out
+
+  // Never add kDepXOut, because Out can always be recomputed from the
+  // forward input X in the backward part.
+  // FIXME(zjl): but in MKLDNN abs, X and Out are all needed...
+  // Developers should not rely on this enum value!
+  kDepXOut = 0x03
+};
+
+template <typename T>
+struct BaseActivationFunctor {
+  using ELEMENT_TYPE = T;
+
+  using AttrPair = std::vector<std::pair<const char*, float*>>;
+
+  AttrPair GetAttrs() { return AttrPair(); }
+
+  /* NOTE(*): Out may reuse the memory of X if X is not needed by the
+     gradient. For example, sigmoid's gradient does not involve x, so its
+     output can reuse the input memory; abs's gradient does use x, so it
+     cannot be computed in place.
+  */
+  bool Inplace() const { return false; }
+};
+
+template <typename Functor>
+bool Activate(const lite::Tensor* X, lite::Tensor* Out) {
+  using T = typename Functor::ELEMENT_TYPE;
+  auto place = lite::fluid::EigenDeviceType<TARGET(kX86)>();
+  CHECK_OR_FALSE(X)
+  CHECK_OR_FALSE(Out)
+  auto x = lite::fluid::EigenVector<T>::Flatten(*X);
+  auto out = lite::fluid::EigenVector<T>::Flatten(*Out);
+  Functor()(place, x, out);
+  return true;
+}
+
+// square(x) = x^2
+template <typename T>
+struct SquareFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.square();
+  }
+};
+
+template <typename T>
+class SquareCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ActivationParam;
+
+  void Run() override {
+    auto& param = *param_.get_mutable<param_t>();
+
+    param.Out->template mutable_data<T>();
+    Activate<SquareFunctor<T>>(param.X, param.Out);
+  }
+
+  virtual ~SquareCompute() = default;
+};
+
+// relu(x) = max(x, 0)
+template <typename T>
+struct ReluFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.cwiseMax(static_cast<T>(0));
+  }
+};
+
+template <typename T>
+class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ActivationParam;
+
+  void Run() override {
+    auto& param = *param_.get_mutable<param_t>();
+
+    param.Out->template mutable_data<T>();
+    Activate<ReluFunctor<T>>(param.X, param.Out);
+  }
+
+  virtual ~ReluCompute() = default;
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/activation_compute_test.cc b/lite/kernels/x86/activation_compute_test.cc
new file mode 100644
index 0000000000..8cc2607e73
--- /dev/null
+++ b/lite/kernels/x86/activation_compute_test.cc
@@ -0,0 +1,83 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
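+//
+// Smoke tests for the ReluCompute kernel declared in activation_compute.h:
+// a registry lookup, precision/target sanity checks, and a run over inputs
+// of alternating sign that verifies negatives are clamped to zero.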
+
+#include "lite/kernels/x86/activation_compute.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <utility>
+#include <vector>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+TEST(relu_x86, retrieve_op) {
+  auto relu =
+      KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("relu");
+  ASSERT_FALSE(relu.empty());
+  ASSERT_TRUE(relu.front());
+}
+
+TEST(relu_x86, init) {
+  ReluCompute<float> relu;
+  ASSERT_EQ(relu.precision(), PRECISION(kFloat));
+  ASSERT_EQ(relu.target(), TARGET(kX86));
+}
+
+TEST(relu_x86, run_test) {
+  lite::Tensor x, out;
+  constexpr int batch_size = 1;
+  std::vector<int64_t> x_shape{batch_size, 3, 2, 2};
+  x.Resize(lite::DDim(x_shape));
+  std::vector<int64_t> out_shape{batch_size, 3, 2, 2};
+  out.Resize(lite::DDim(out_shape));
+
+  auto x_data = x.mutable_data<float>();
+  auto out_data = out.mutable_data<float>();
+
+  for (int64_t i = 0; i < x.dims().production(); i++) {
+    int sign = i % 2 == 0 ? 1 : -1;
+    x_data[i] = static_cast<float>(i * sign);
+  }
+
+  ReluCompute<float> relu;
+  operators::ActivationParam param;
+
+  param.X = &x;
+  param.Out = &out;
+
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  ctx->As<X86Context>();
+  relu.SetContext(std::move(ctx));
+  relu.SetParam(param);
+  relu.Run();
+
+  LOG(INFO) << "output: ";
+  for (int i = 0; i < out.dims().production(); i++) {
+    LOG(INFO) << out_data[i];
+    // Even indices held non-negative inputs and pass through unchanged;
+    // odd indices held negative inputs and must come out as zero.
+    int sign = i % 2 == 0 ? 1 : 0;
+    ASSERT_EQ(out_data[i], i * sign);
+  }
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(relu, kX86, kFloat, kNCHW, def);
diff --git a/lite/kernels/x86/gru_compute.cc b/lite/kernels/x86/gru_compute.cc
new file mode 100644
index 0000000000..c1b6c2caa9
--- /dev/null
+++ b/lite/kernels/x86/gru_compute.cc
@@ -0,0 +1,36 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
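+
+// Registers the float GRU kernel and defines the paddle_num_threads flag
+// that gru_compute.h picks up via DECLARE_int32 to decide whether the MKL
+// packed-GEMM fast path should be used.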
+
+#include "lite/kernels/x86/gru_compute.h"
+
+DEFINE_int32(paddle_num_threads,
+             1,
+             "Number of threads for each paddle instance.");
+
+REGISTER_LITE_KERNEL(gru,
+                     kX86,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::x86::GRUCompute<float>,
+                     def)
+    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("H0", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Weight", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Batch_gate", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Batch_reset_hidden_prev",
+                {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Batch_hidden", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
diff --git a/lite/kernels/x86/gru_compute.h b/lite/kernels/x86/gru_compute.h
new file mode 100644
index 0000000000..e3c6f70fdb
--- /dev/null
+++ b/lite/kernels/x86/gru_compute.h
@@ -0,0 +1,221 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include <string>
+#include <vector>
+#include "lite/backends/x86/math/blas.h"
+#include "lite/backends/x86/math/detail/gru_cpu_kernel.h"
+#include "lite/backends/x86/math/detail/gru_kernel.h"
+#include "lite/backends/x86/math/gru_compute.h"
+#include "lite/backends/x86/math/math_function.h"
+#include "lite/backends/x86/math/sequence2batch.h"
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/types.h"
+#include "lite/fluid/eigen.h"
+
+DECLARE_int32(paddle_num_threads);
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+using Tensor = lite::Tensor;
+
+template <lite::TargetType Target, typename T>
+inline void ReorderInitState(const lite::Context<Target>& context,
+                             const Tensor& src,
+                             const std::vector<uint64_t>& index_lod,
+                             Tensor* dst,
+                             bool indexed_src) {
+  lite::x86::math::CopyMatrixRowsFunctor<Target, T> row_shuffle;
+  dst->Resize(src.dims());
+  dst->mutable_data<T>();
+  row_shuffle(context, src, index_lod, dst, indexed_src);
+}
+
+template <typename T>
+class GRUCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  void Run() override {
+    auto& context = ctx_->As<X86Context>();
+    auto& param = *param_.get_mutable<operators::GRUParam>();
+
+    bool origin_mode = param.origin_mode;
+    bool is_reverse = param.is_reverse;
+
+    auto* input = param.input;
+    auto* h0 = param.h0;
+    auto* weight = param.weight;
+    const T* weight_data = weight->data<T>();
+    auto* bias = param.bias;
+
+    auto* batch_gate = param.batch_gate;
+    batch_gate->mutable_data<T>();
+    auto* batch_reset_hidden_prev = param.batch_reset_hidden_prev;
+    batch_reset_hidden_prev->mutable_data<T>();
+    auto* batch_hidden = param.batch_hidden;
+    batch_hidden->mutable_data<T>();
+    auto* hidden = param.hidden;
+    hidden->mutable_data<T>();
+
+    auto hidden_dims = hidden->dims();
+
+    lite::x86::math::LoDTensor2BatchFunctor<TARGET(kX86), T> to_batch;
+    to_batch(context, *input, batch_gate, true, is_reverse);
+
+    if (bias) {
+      lite::x86::math::RowwiseAdd<TARGET(kX86), T> add_bias;
+      add_bias(context, *batch_gate, *bias, batch_gate);
+    }
+
+    int frame_size = hidden_dims[1];
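+    // Weight layout assumed by GRUMetaValue below: the first
+    // 2 * frame_size * frame_size elements of `weight` hold the update/reset
+    // gate weights, and the remaining frame_size * frame_size elements hold
+    // the candidate-state weights.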
+    lite::x86::math::GRUMetaValue<T> gru_value;
+    gru_value.gate_weight = const_cast<T*>(weight_data);
+    gru_value.state_weight =
+        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
+    Tensor ordered_h0;
+
+    std::vector<uint64_t> order(batch_gate->lod()[2]);
+
+    if (h0) {
+      // Batch computation reorders the input sequences by length, so the
+      // initial hidden state has to be shuffled into the same order.
+      ReorderInitState<TARGET(kX86), T>(
+          context, *h0, order, &ordered_h0, true);
+      gru_value.prev_out_value = ordered_h0.mutable_data<T>();
+    } else {
+      gru_value.prev_out_value = nullptr;
+    }
+    auto batch_starts = batch_gate->lod()[0];
+    size_t seq_len = batch_starts.size() - 1;
+    auto active_node =
+        lite::x86::math::detail::GetActivationType(param.activation);
+    auto active_gate =
+        lite::x86::math::detail::GetActivationType(param.gate_activation);
+
+#ifdef PADDLE_WITH_MKLML
+    // use MKL packed to speedup GEMM
+    if (FLAGS_paddle_num_threads >= 4) {
+      auto blas = lite::x86::math::GetBlas<TARGET(kX86), T>(context);
+      T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix,
+                                       1 /*height of C*/,
+                                       frame_size * 2 /*width of weight*/,
+                                       frame_size /*height of height*/);
+      CHECK(packed_gate);
+      blas.GEMM_PACK(CblasBMatrix,
+                     CblasNoTrans,
+                     1 /*cur bs?*/,
+                     frame_size * 2,
+                     frame_size,
+                     T(1.0),
+                     gru_value.gate_weight,
+                     frame_size * 2,
+                     packed_gate);
+      T* packed_state = blas.GEMM_ALLOC(CblasBMatrix,
+                                        1 /*height of C*/,
+                                        frame_size /*width of weight*/,
+                                        frame_size /*height of height*/);
+      CHECK(packed_state);
+      blas.GEMM_PACK(CblasBMatrix,
+                     CblasNoTrans,
+                     1 /*cur bs?*/,
+                     frame_size,
+                     frame_size,
+                     T(1.0),
+                     gru_value.state_weight,
+                     frame_size,
+                     packed_state);
+      for (size_t n = 0; n < seq_len; n++) {
+        int64_t bstart = static_cast<int64_t>(batch_starts[n]);
+        int64_t bend = static_cast<int64_t>(batch_starts[n + 1]);
+        int64_t cur_batch_size = bend - bstart;
+
+        Tensor gate_t = batch_gate->Slice<T>(bstart, bend);
+        Tensor reset_hidden_prev_t =
+            batch_reset_hidden_prev->Slice<T>(bstart, bend);
+        Tensor hidden_t = batch_hidden->Slice<T>(bstart, bend);
+        gru_value.output_value = hidden_t.mutable_data<T>();
+        gru_value.gate_value = gate_t.mutable_data<T>();
+        gru_value.reset_output_value = reset_hidden_prev_t.mutable_data<T>();
+
+        if (gru_value.prev_out_value) {
+          blas.GEMM_COMPUTE(CblasNoTrans,
+                            CblasPacked,
+                            cur_batch_size,
+                            frame_size * 2,
+                            frame_size,
+                            gru_value.prev_out_value,
+                            frame_size,
+                            packed_gate,
+                            frame_size * 2,
+                            T(1),
+                            gru_value.gate_value,
+                            frame_size * 3);
+        }
+
+        // Mirror GRUUnitFunctor: activate the update/reset gates, then
+        // multiply the reset-gated hidden state by the packed candidate
+        // weight before computing the final output.
+        lite::x86::math::detail::forward_reset_output(
+            lite::x86::math::detail::forward::gru_resetOutput<T>(),
+            gru_value,
+            frame_size,
+            cur_batch_size,
+            active_gate);
+
+        if (gru_value.prev_out_value) {
+          blas.GEMM_COMPUTE(CblasNoTrans,
+                            CblasPacked,
+                            cur_batch_size,
+                            frame_size,
+                            frame_size,
+                            gru_value.reset_output_value,
+                            frame_size,
+                            packed_state,
+                            frame_size,
+                            T(1),
+                            gru_value.gate_value + frame_size * 2,
+                            frame_size * 3);
+        }
+
+        lite::x86::math::detail::forward_final_output(
+            lite::x86::math::detail::forward::gru_finalOutput<T>(),
+            gru_value,
+            frame_size,
+            cur_batch_size,
+            active_node,
+            origin_mode);
+
+        gru_value.prev_out_value = gru_value.output_value;
+      }
+
+      blas.GEMM_FREE(packed_gate);
+      blas.GEMM_FREE(packed_state);
+    } else {
+#endif
+      for (size_t n = 0; n < seq_len; n++) {
+        int64_t bstart = static_cast<int64_t>(batch_starts[n]);
+        int64_t bend = static_cast<int64_t>(batch_starts[n + 1]);
+        int64_t cur_batch_size = bend - bstart;
+
+        Tensor gate_t = batch_gate->Slice<T>(bstart, bend);
+        Tensor reset_hidden_prev_t =
+            batch_reset_hidden_prev->Slice<T>(bstart, bend);
+        Tensor hidden_t = batch_hidden->Slice<T>(bstart, bend);
+        gru_value.output_value = hidden_t.mutable_data<T>();
+        gru_value.gate_value = gate_t.mutable_data<T>();
+        gru_value.reset_output_value = reset_hidden_prev_t.mutable_data<T>();
+
+        lite::x86::math::GRUUnitFunctor<TARGET(kX86), T>::compute(
+            context,
+            gru_value,
+            frame_size,
+            cur_batch_size,
+            active_node,
+            active_gate,
+            origin_mode);
+
+        gru_value.prev_out_value = gru_value.output_value;
+      }
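+      // This sequential path computes the same result as the MKL packed
+      // branch above; GRUUnitFunctor simply performs the gate GEMMs with the
+      // raw weight matrices instead of pre-packed ones.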
+#ifdef PADDLE_WITH_MKLML
+    }
+#endif
+    lite::x86::math::Batch2LoDTensorFunctor<TARGET(kX86), T> to_seq;
+    batch_hidden->set_lod(batch_gate->lod());
+    to_seq(context, *batch_hidden, hidden);
+  }
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/gru_compute_test.cc b/lite/kernels/x86/gru_compute_test.cc
new file mode 100644
index 0000000000..3e0e944f23
--- /dev/null
+++ b/lite/kernels/x86/gru_compute_test.cc
@@ -0,0 +1,155 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/gru_compute.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <utility>
+#include <vector>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+TEST(gru_x86, retrieve_op) {
+  auto gru =
+      KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("gru");
+  ASSERT_FALSE(gru.empty());
+  ASSERT_TRUE(gru.front());
+}
+
+TEST(gru_x86, init) {
+  GRUCompute<float> gru;
+  ASSERT_EQ(gru.precision(), PRECISION(kFloat));
+  ASSERT_EQ(gru.target(), TARGET(kX86));
+}
+
+TEST(gru_x86, run_test) {
+  lite::Tensor input, h0, weight, bias;
+  lite::Tensor batch_gate, batch_reset_hidden_prev, batch_hidden, hidden;
+  constexpr int batch_size = 9;
+  std::vector<int64_t> input_shape{batch_size, 15};
+  input.Resize(lite::DDim(input_shape));
+  std::vector<int64_t> weight_shape{5, 15};
+  weight.Resize(lite::DDim(weight_shape));
+  std::vector<int64_t> h0_shape{3, 5};
+  h0.Resize(lite::DDim(h0_shape));
+  std::vector<int64_t> bias_shape{1, 15};
+  bias.Resize(lite::DDim(bias_shape));
+  std::vector<int64_t> batch_gate_shape{batch_size, 15};
+  batch_gate.Resize(lite::DDim(batch_gate_shape));
+  std::vector<int64_t> batch_reset_hidden_prev_shape{batch_size, 5};
+  batch_reset_hidden_prev.Resize(lite::DDim(batch_reset_hidden_prev_shape));
+  std::vector<int64_t> batch_hidden_shape{batch_size, 5};
+  batch_hidden.Resize(lite::DDim(batch_hidden_shape));
+  std::vector<int64_t> hidden_shape{batch_size, 5};
+  hidden.Resize(lite::DDim(hidden_shape));
+
+  std::vector<std::vector<uint64_t>> lod{{0, 2, 6, 9}};
+  input.set_lod(lod);
+
+  auto input_data = input.mutable_data<float>();
+  auto weight_data = weight.mutable_data<float>();
+  auto h0_data = h0.mutable_data<float>();
+  auto bias_data = bias.mutable_data<float>();
+
+  for (int64_t i = 0; i < input.dims().production(); i++) {
+    input_data[i] = static_cast<float>(0);
+  }
+  for (int64_t i = 0; i < weight.dims().production(); i++) {
+    weight_data[i] = static_cast<float>(0);
+  }
+  for (int64_t i = 0; i < h0.dims().production(); i++) {
+    h0_data[i] = static_cast<float>(0);
+  }
+  for (int64_t i = 0; i < bias.dims().production(); i++) {
+    bias_data[i] = static_cast<float>(0);
+  }
+
+  GRUCompute<float> gru;
+  operators::GRUParam param;
+
+  param.input = &input;
+  param.h0 = &h0;
+  param.weight = &weight;
+  param.bias = &bias;
+  param.batch_gate = &batch_gate;
+  param.batch_reset_hidden_prev = &batch_reset_hidden_prev;
+  param.batch_hidden = &batch_hidden;
+  param.hidden = &hidden;
+  param.gate_activation = "sigmoid";
+  param.activation = "tanh";
+  param.is_reverse = false;
+  param.origin_mode = false;
+
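+  // With all-zero input, weight, bias and h0, every gate pre-activation is
+  // zero: the two sigmoid gate blocks of batch_gate come out as
+  // sigmoid(0) = 0.5, the tanh candidate block stays tanh(0) = 0, and the
+  // hidden state remains zero. The expected vectors below encode that.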
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  ctx->As<X86Context>();
+  gru.SetContext(std::move(ctx));
+  gru.SetParam(param);
+  gru.Run();
+
+  auto batch_gate_data = batch_gate.mutable_data<float>();
+  auto batch_reset_hidden_prev_data =
+      batch_reset_hidden_prev.mutable_data<float>();
+  auto batch_hidden_data = batch_hidden.mutable_data<float>();
+  auto hidden_data = hidden.mutable_data<float>();
+  std::vector<float> batch_gate_out{
+      0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0,
+      0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0,
+      0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0,
+      0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0,
+      0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0,
+      0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0,
+      0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0,
+      0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0,
+      0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0};
+  std::vector<float> batch_reset_hidden_prev_out{
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+  std::vector<float> batch_hidden_out{
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+  std::vector<float> hidden_out{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+
+  LOG(INFO) << "output: ";
+  for (int i = 0; i < batch_gate.dims().production(); i++) {
+    LOG(INFO) << batch_gate_data[i];
+    EXPECT_NEAR(batch_gate_data[i], batch_gate_out[i], 1e-3);
+  }
+  for (int i = 0; i < batch_reset_hidden_prev.dims().production(); i++) {
+    LOG(INFO) << batch_reset_hidden_prev_data[i];
+    EXPECT_NEAR(
+        batch_reset_hidden_prev_data[i], batch_reset_hidden_prev_out[i], 1e-3);
+  }
+  for (int i = 0; i < batch_hidden.dims().production(); i++) {
+    LOG(INFO) << batch_hidden_data[i];
+    EXPECT_NEAR(batch_hidden_data[i], batch_hidden_out[i], 1e-3);
+  }
+  for (int i = 0; i < hidden.dims().production(); i++) {
+    LOG(INFO) << hidden_data[i];
+    EXPECT_NEAR(hidden_data[i], hidden_out[i], 1e-3);
+  }
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(gru, kX86, kFloat, kNCHW, def);
diff --git a/lite/kernels/x86/relu_compute_test.cc b/lite/kernels/x86/relu_compute_test.cc
index ec446de73f..37ed6db7f9 100644
--- a/lite/kernels/x86/relu_compute_test.cc
+++ b/lite/kernels/x86/relu_compute_test.cc
@@ -12,11 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "lite/kernels/x86/relu_compute.h"
 #include <gtest/gtest.h>
 #include <memory>
 #include <vector>
 #include "lite/core/op_registry.h"
+#include "lite/kernels/x86/activation_compute.h"
 
 namespace paddle {
 namespace lite {
@@ -64,6 +64,8 @@ TEST(relu_x86, run_test) {
   LOG(INFO) << "output: ";
   for (int i = 0; i < out.dims().production(); i++) {
     LOG(INFO) << out_data[i];
+    int sign = i % 2 == 0 ? 1 : 0;
+    ASSERT_EQ(out_data[i], i * sign);
   }
 }
 
diff --git a/lite/kernels/x86/sequence_expand_as_compute.cc b/lite/kernels/x86/sequence_expand_as_compute.cc
new file mode 100644
index 0000000000..4e03096976
--- /dev/null
+++ b/lite/kernels/x86/sequence_expand_as_compute.cc
@@ -0,0 +1,26 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/sequence_expand_as_compute.h"
+
+REGISTER_LITE_KERNEL(
+    sequence_expand_as,
+    kX86,
+    kFloat,
+    kNCHW,
+    paddle::lite::kernels::x86::SequenceExpandAsCompute<float>,
+    def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
diff --git a/lite/kernels/x86/sequence_expand_as_compute.h b/lite/kernels/x86/sequence_expand_as_compute.h
new file mode 100644
index 0000000000..16759c1b9f
--- /dev/null
+++ b/lite/kernels/x86/sequence_expand_as_compute.h
@@ -0,0 +1,81 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
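+//
+// sequence_expand_as repeats each row of X according to the reference LoD
+// of Y. Illustrative example (values are hypothetical, not from this patch):
+//   X = [[a], [b]] and Y lod = {0, 2, 5}
+//   row 0 covers span [0, 2) -> copied twice; row 1 covers [2, 5) -> thrice
+//   Out = [[a], [a], [b], [b], [b]], and Out inherits Y's lod.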
+#pragma once
+
+#include <algorithm>
+#include <vector>
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/types.h"
+#include "lite/fluid/eigen.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+using Tensor = lite::Tensor;
+
+template <typename T>
+struct SequenceExpandFunctor {
+  void operator()(const Tensor &x,
+                  const std::vector<uint64_t> &ref_lod, /*expand referenced lod*/
+                  Tensor *out) {
+    int64_t height = x.dims()[0];
+    int64_t width = x.data_size() / height;
+
+    const T *in_data = x.data<T>();
+    T *out_data = out->mutable_data<T>();
+
+    for (int h_id = 0; h_id < height; ++h_id) {
+      size_t span = ref_lod[h_id + 1] - ref_lod[h_id];
+      if (span == 0) continue;
+      const T *src = in_data + h_id * width;
+      for (int64_t w_id = 0; w_id < width; ++w_id) {
+        T ele = src[w_id];
+        size_t offset = ref_lod[h_id] * width;
+        for (size_t k = 0; k < span; ++k) {
+          out_data[offset + k * width + w_id] = ele;
+        }
+      }
+    }
+  }
+};
+
+template <typename T>
+class SequenceExpandAsCompute
+    : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  void Run() override {
+    auto &param = *param_.get_mutable<operators::SequenceExpandAsParam>();
+
+    auto *x = param.x;
+    auto *y = param.y;
+    auto *out = param.out;
+
+    auto &y_lod = y->lod();
+    CHECK_EQ(y_lod.size(), 1);
+    CHECK_GT(y_lod[0].size(), 1);
+
+    out->mutable_data<T>();
+
+    SequenceExpandFunctor<T> seq_expand_functor;
+    seq_expand_functor(*x, y_lod[0], out);
+  }
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/sequence_expand_as_compute_test.cc b/lite/kernels/x86/sequence_expand_as_compute_test.cc
new file mode 100644
index 0000000000..d49fdbb7a6
--- /dev/null
+++ b/lite/kernels/x86/sequence_expand_as_compute_test.cc
@@ -0,0 +1,96 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
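+
+// The run_test below expands X = {0, 1, 2, 3} (shape [4, 1]) with
+// Y lod = {0, 3, 6, 7, 8}: rows 0 and 1 are each repeated three times and
+// rows 2 and 3 once, so Out should be {0, 0, 0, 1, 1, 1, 2, 3} (shape [8, 1]).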
+
+#include "lite/kernels/x86/sequence_expand_as_compute.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <utility>
+#include <vector>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+TEST(sequence_expand_as_x86, retrieve_op) {
+  auto sequence_expand_as =
+      KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
+          "sequence_expand_as");
+  ASSERT_FALSE(sequence_expand_as.empty());
+  ASSERT_TRUE(sequence_expand_as.front());
+}
+
+TEST(sequence_expand_as_x86, init) {
+  SequenceExpandAsCompute<float> sequence_expand_as;
+  ASSERT_EQ(sequence_expand_as.precision(), PRECISION(kFloat));
+  ASSERT_EQ(sequence_expand_as.target(), TARGET(kX86));
+}
+
+TEST(sequence_expand_as_x86, run_test) {
+  lite::Tensor x, y, out;
+  std::vector<int64_t> x_shape{4, 1};
+  x.Resize(lite::DDim(x_shape));
+  std::vector<int64_t> y_shape{1, 5};
+  y.Resize(lite::DDim(y_shape));
+  std::vector<int64_t> out_shape{8, 1};
+  out.Resize(lite::DDim(out_shape));
+
+  auto x_data = x.mutable_data<float>();
+  auto y_data = y.mutable_data<float>();
+
+  for (int64_t i = 0; i < x.dims().production(); i++) {
+    x_data[i] = static_cast<float>(i);
+  }
+  for (int64_t i = 0; i < y.dims().production(); i++) {
+    y_data[i] = static_cast<float>(i);
+  }
+
+  std::vector<std::vector<uint64_t>> lod{{0, 3, 6, 7, 8}};
+  y.set_lod(lod);
+
+  SequenceExpandAsCompute<float> sequence_expand_as;
+  operators::SequenceExpandAsParam param;
+
+  param.x = &x;
+  param.y = &y;
+  param.out = &out;
+
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  ctx->As<X86Context>();
+  sequence_expand_as.SetContext(std::move(ctx));
+  sequence_expand_as.SetParam(param);
+  sequence_expand_as.Run();
+  auto out_data = out.mutable_data<float>();
+
+  // Walk the lod spans: output element i belongs to input row (index - 1)
+  // once i crosses the running lod boundary.
+  int index = 1;
+  int lod_sum = lod[0][index];
+  LOG(INFO) << "output: ";
+  for (int i = 0; i < out.dims().production(); i++) {
+    LOG(INFO) << out_data[i];
+    if (i >= lod_sum) {
+      index++;
+      lod_sum = lod[0][index];
+    }
+    ASSERT_EQ(out_data[i], x_data[index - 1]);
+  }
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(sequence_expand_as, kX86, kFloat, kNCHW, def);
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index b992b12831..44c42962f5 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -71,6 +71,7 @@ add_operator(roi_align_op basic SRCS roi_align_op.cc DEPS ${op_DEPS})
 add_operator(box_clip_op basic SRCS box_clip_op.cc DEPS ${op_DEPS})
 add_operator(flatten_op basic SRCS flatten_op.cc DEPS ${op_DEPS})
 add_operator(fake_quantize_range_abs_max_op basic SRCS fake_quantize_range_abs_max.cc DEPS ${op_DEPS})
+add_operator(sequence_expand_as_op_lite basic SRCS sequence_expand_as_op.cc DEPS ${op_DEPS})
 add_operator(range_op basic SRCS range_op.cc DEPS ${op_DEPS})
 add_operator(assign_value_op basic SRCS assign_value_op.cc DEPS ${op_DEPS})
 
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 18d7c412fe..c1f2b12cb4 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -682,6 +682,12 @@ struct SequenceExpandParam {
   int ref_level{-1};
 };
 
+struct SequenceExpandAsParam {
+  const lite::Tensor* x{nullptr};
+  const lite::Tensor* y{nullptr};
+  lite::Tensor* out{nullptr};
+};
+
 struct ReduceMaxParam {
   const lite::Tensor* X{};
   lite::Tensor* Out{};
diff --git a/lite/operators/sequence_expand_as_op.cc b/lite/operators/sequence_expand_as_op.cc
new file mode 100644
index 0000000000..22a4743103
--- /dev/null
+++ b/lite/operators/sequence_expand_as_op.cc
@@ -0,0 +1,76 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
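+//
+// Shape contract enforced below: X must be 2-D with one row per sequence of
+// Y (x_dims[0] == y_lod[0].size() - 1); InferShape then sets Out's first
+// dimension to the summed span of Y's lod (the total expanded length) and
+// copies Y's lod onto Out.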
+
+#include "lite/operators/sequence_expand_as_op.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool SequenceExpandAsOpLite::CheckShape() const {
+  CHECK_OR_FALSE(param_.x)
+  CHECK_OR_FALSE(param_.y)
+  CHECK_OR_FALSE(param_.out)
+
+  auto x_dims = param_.x->dims();
+  CHECK_EQ_OR_FALSE(x_dims.size(), 2)
+  auto y_lod = param_.y->lod();
+  CHECK_EQ_OR_FALSE(y_lod.size(), 1)
+  CHECK_EQ_OR_FALSE(static_cast<size_t>(x_dims[0]), y_lod[0].size() - 1)
+
+  return true;
+}
+
+bool SequenceExpandAsOpLite::InferShape() const {
+  auto x_dims = param_.x->dims();
+  auto y_lod = param_.y->lod();
+  auto out_dims = x_dims;
+
+  int64_t out_first_dim = 0;
+  if (y_lod[0].size() <= 1) {
+    out_first_dim = x_dims[0];
+  } else {
+    // The summed spans telescope to y_lod[0].back() - y_lod[0].front().
+    for (size_t i = 1; i < y_lod[0].size(); ++i) {
+      out_first_dim += (y_lod[0][i] - y_lod[0][i - 1]);
+    }
+  }
+  out_dims[0] = out_first_dim;
+
+  param_.out->Resize(out_dims);
+  param_.out->set_lod(y_lod);
+
+  return true;
+}
+
+bool SequenceExpandAsOpLite::AttachImpl(const cpp::OpDesc &op_desc,
+                                        lite::Scope *scope) {
+  auto x = op_desc.Input("X").front();
+  auto y = op_desc.Input("Y").front();
+  auto out = op_desc.Output("Out").front();
+
+  param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
+  param_.y = scope->FindVar(y)->GetMutable<lite::Tensor>();
+  param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
+
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(sequence_expand_as,
+                 paddle::lite::operators::SequenceExpandAsOpLite)
diff --git a/lite/operators/sequence_expand_as_op.h b/lite/operators/sequence_expand_as_op.h
new file mode 100644
index 0000000000..2eae8a26da
--- /dev/null
+++ b/lite/operators/sequence_expand_as_op.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
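+//
+// Standard OpLite lifecycle for this op: AttachImpl pulls the X/Y/Out
+// variables out of the scope into param_, CheckShape validates them,
+// InferShape resizes Out, and AttachKernel hands param_ to the bound kernel.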
+
+#pragma once
+#include <string>
+#include <vector>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class SequenceExpandAsOpLite : public OpLite {
+ public:
+  SequenceExpandAsOpLite() {}
+  explicit SequenceExpandAsOpLite(const std::string &op_type)
+      : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override { return "sequence_expand_as"; }
+
+ private:
+  mutable SequenceExpandAsParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
-- 
GitLab