diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index c8da4fc6698df2ce0f5e009070ef61f351b28f5c..1de257a603af016b899a07d2022d81422f64e8dd 100644 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -9,6 +9,7 @@ add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${li add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc DEPS ${lite_kernel_deps}) add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc DEPS ${lite_kernel_deps}) add_kernel(squeeze_compute_x86 X86 basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(fill_constant_batch_size_like_compute_x86 X86 basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps} math_function) add_kernel(reshape_compute_x86 X86 basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op) # lite_cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_sub_op elementwise_add_op) # lite_cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) @@ -45,6 +46,7 @@ add_kernel(matmul_compute_x86 X86 basic SRCS matmul_compute.cc DEPS ${lite_kerne lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86) lite_cc_test(test_slice_compute_x86 SRCS slice_compute_test.cc DEPS slice_compute_x86) lite_cc_test(test_squeeze_compute_x86 SRCS squeeze_compute_test.cc DEPS squeeze_compute_x86) +lite_cc_test(test_fill_constant_batch_size_like_compute_x86 SRCS fill_constant_batch_size_like_compute_test.cc DEPS fill_constant_batch_size_like_compute_x86) lite_cc_test(test_reshape_compute_x86 SRCS reshape_compute_test.cc DEPS reshape_compute_x86) lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86) lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_x86) diff --git a/lite/kernels/x86/fill_constant_batch_size_like_compute.cc b/lite/kernels/x86/fill_constant_batch_size_like_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..03f72e3056094ea1262a90057de1486dd349d554 --- /dev/null +++ b/lite/kernels/x86/fill_constant_batch_size_like_compute.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/fill_constant_batch_size_like_compute.h" + +REGISTER_LITE_KERNEL( + fill_constant_batch_size_like, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::FillConstantBatchSizeLikeCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/fill_constant_batch_size_like_compute.h b/lite/kernels/x86/fill_constant_batch_size_like_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..1f157a0db77dae8b6cb7eca936eb942517986331 --- /dev/null +++ b/lite/kernels/x86/fill_constant_batch_size_like_compute.h @@ -0,0 +1,57 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "lite/backends/x86/math/blas.h" +#include "lite/backends/x86/math/math_function.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/core/types.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class FillConstantBatchSizeLikeCompute + : public KernelLite { + public: + using param_t = operators::FillConstantBatchSizeLikeParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto* out = param.Out; + auto* in = param.Input; + if (in->lod().size() && param.input_dim_idx == 0) { + // set the correct batch size for the LoDTensor. + auto odims = out->dims(); + int output_dim_idx = param.output_dim_idx; + odims[output_dim_idx] = static_cast(in->lod().back().size()) - 1; + } + auto value = param.value; + + paddle::lite::x86::math::SetConstant setter; + Context ctx; + setter(ctx, out, static_cast(value)); + } + + virtual ~FillConstantBatchSizeLikeCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/fill_constant_batch_size_like_compute_test.cc b/lite/kernels/x86/fill_constant_batch_size_like_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a071aec16dafeff294c097ffe57d64a983e0af4e --- /dev/null +++ b/lite/kernels/x86/fill_constant_batch_size_like_compute_test.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/fill_constant_batch_size_like_compute.h" +#include +#include +#include +#include +#include +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(fill_constant_batch_size_like_x86, retrive_op) { + auto fill_constant_batch_size_like = + KernelRegistry::Global().Create( + "fill_constant_batch_size_like"); + ASSERT_FALSE(fill_constant_batch_size_like.empty()); + ASSERT_TRUE(fill_constant_batch_size_like.front()); +} + +TEST(fill_constant_batch_size_like_x86, init) { + lite::kernels::x86::FillConstantBatchSizeLikeCompute + fill_constant_batch_size_like; + ASSERT_EQ(fill_constant_batch_size_like.precision(), PRECISION(kFloat)); + ASSERT_EQ(fill_constant_batch_size_like.target(), TARGET(kX86)); +} + +TEST(fill_constant_batch_size_like_x86, run_test) { + lite::Tensor input; + lite::Tensor out; + std::vector input_shape{219, 232}; + input.Resize(input_shape); + std::vector out_shape{219, 132, 7}; + + auto input_data = input.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < input.dims().production(); ++i) { + input_data[i] = static_cast(i); + } + + FillConstantBatchSizeLikeCompute fill_constant_batch_size_like; + operators::FillConstantBatchSizeLikeParam param; + param.Input = &input; + param.Out = &out; + std::vector shape{-1, 132, 7}; + float value = 3.5; + param.shape = shape; + param.value = value; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + fill_constant_batch_size_like.SetParam(param); + fill_constant_batch_size_like.Run(); + + for (int i = 0; i < out.dims().production(); i++) { + LOG(INFO) << out_data[i]; + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(fill_constant_batch_size_like, kX86, kFloat, kNCHW, def); diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 44c42962f51550dda75c78cc2d61961d5dd105ea..7b456222fa2e8f8b53cea3258f5b174832522695 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -25,6 +25,7 @@ add_operator(multiclass_nms_op_lite basic SRCS multiclass_nms_op.cc DEPS ${op_DE add_operator(fusion_elementwise_activation_ops basic SRCS fusion_elementwise_activation_ops.cc DEPS elementwise_ops ${op_DEPS}) add_operator(mean_op basic SRCS mean_op.cc DEPS ${op_DEPS}) add_operator(fill_constant_op basic SRCS fill_constant_op.cc DEPS ${op_DEPS}) +add_operator(fill_constant_batch_size_like_op basic SRCS fill_constant_batch_size_like_op.cc DEPS ${op_DEPS}) #add_operator(sgd_op basic SRCS sgd_op.cc DEPS ${op_DEPS}) add_operator(uniform_random_op basic SRCS uniform_random_op.cc DEPS ${op_DEPS}) add_operator(power_op basic SRCS power_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/fill_constant_batch_size_like_op.cc b/lite/operators/fill_constant_batch_size_like_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3a28959e48a2560938eb19bc2868c473ac8fde03 --- /dev/null +++ b/lite/operators/fill_constant_batch_size_like_op.cc @@ -0,0 +1,69 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/fill_constant_batch_size_like_op.h" +#include +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool FillConstantBatchSizeLikeOp::CheckShape() const { + CHECK_OR_FALSE(param_.Input); + CHECK_OR_FALSE(param_.Out); + return true; +} + +bool FillConstantBatchSizeLikeOp::InferShape() const { + auto shape = param_.shape; + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), [](int a) { + return static_cast(a); + }); + lite::DDim output_dim(shape_int64); + + int input_dim_idx = param_.input_dim_idx; + int output_dim_idx = param_.output_dim_idx; + + output_dim[output_dim_idx] = param_.Input->dims()[input_dim_idx]; + param_.Out->Resize(output_dim); + return true; +} + +bool FillConstantBatchSizeLikeOp::AttachImpl(const cpp::OpDesc &op_desc, + lite::Scope *scope) { + auto Input = op_desc.Input("X").front(); + auto Out = op_desc.Output("Out").front(); + param_.Input = scope->FindVar(Input)->GetMutable(); + param_.Out = scope->FindVar(Out)->GetMutable(); + param_.shape = op_desc.GetAttr>("shape"); + param_.input_dim_idx = op_desc.GetAttr("input_dim_idx"); + param_.output_dim_idx = op_desc.GetAttr("output_dim_idx"); + param_.dtype = op_desc.GetAttr("dtype"); + param_.value = op_desc.GetAttr("value"); + CHECK(param_.Input); + CHECK(param_.Out); + + return true; +} + +} /* namespace operators */ +} /* namespace lite */ +} /* namespace paddle */ + +REGISTER_LITE_OP(fill_constant_batch_size_like, + paddle::lite::operators::FillConstantBatchSizeLikeOp); diff --git a/lite/operators/fill_constant_batch_size_like_op.h b/lite/operators/fill_constant_batch_size_like_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b073ba8379e5e52fcd3a2d0ee28aaaf5ceaea678 --- /dev/null +++ b/lite/operators/fill_constant_batch_size_like_op.h @@ -0,0 +1,50 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class FillConstantBatchSizeLikeOp : public OpLite { + public: + FillConstantBatchSizeLikeOp() {} + + explicit FillConstantBatchSizeLikeOp(const std::string &op_type) + : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { + return "fill_constant_batch_size_like"; + } + + private: + mutable FillConstantBatchSizeLikeParam param_; +}; + +} /* namespace operators */ +} /* namespace lite */ +} /* namespace paddle */ diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index c1f2b12cb46f6586f96550cf1fc973fed7bab25c..f7e0bcd4e498741ab911c5daf4358a72b646f009 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -371,6 +371,17 @@ struct FillConstantParam { lite::Tensor* Out{}; }; +struct FillConstantBatchSizeLikeParam { + lite::Tensor* Input; + lite::Tensor* Out; + + std::vector shape; + int input_dim_idx{0}; + int output_dim_idx{0}; + int dtype{static_cast(VarDescAPI::VarDataType::FP32)}; + float value{0.0f}; +}; + // struct FakeQuantizeMovingAvgMaxAbsParam { const lite::Tensor* x{};