diff --git a/lite/core/mir/fusion/conv_activation_fuse_pass.cc b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
index ff064fb2ee93fc540e932da36fb07bb78eef989a..0d11b47db6a7f767f8cd032877d8647b0872b8d4 100644
--- a/lite/core/mir/fusion/conv_activation_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
@@ -47,4 +47,5 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 REGISTER_MIR_PASS(lite_conv_activation_fuse_pass,
                   paddle::lite::mir::ConvActivationFusePass)
     .BindTargets({TARGET(kAny)})
+    .ExcludeTargets({TARGET(kXPU)})
     .BindKernel("conv2d");
diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass.cc b/lite/core/mir/fusion/conv_bn_fuse_pass.cc
index d9d9c1bbf55bd33c31aa9a22de934d4eae8657c6..5ab5f8c0a4797e51cce656de43883a68d4931e9b 100644
--- a/lite/core/mir/fusion/conv_bn_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_bn_fuse_pass.cc
@@ -45,4 +45,4 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass)
     .BindTargets({TARGET(kAny)})
-    .ExcludeTargets({TARGET(kX86)});
+    .ExcludeTargets({TARGET(kX86), TARGET(kXPU)});
diff --git a/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc b/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc
index fd9aadc5d01c2cb3b6c7a3e888503072a0798725..b1b492ce030c7a46d8b23936c1661f3d743eb9cb 100644
--- a/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc
@@ -46,4 +46,5 @@ void ConvElementwiseFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 REGISTER_MIR_PASS(lite_conv_elementwise_fuse_pass,
                   paddle::lite::mir::ConvElementwiseFusePass)
-    .BindTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)})
+    .ExcludeTargets({TARGET(kXPU)});
diff --git a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
index af66f5ab66bd09907cb9d28f00f17d983e54c252..e4391cd24287cafe457074733ba73208288c3375 100644
--- a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
+++ b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
@@ -35,4 +35,5 @@ void ElementwiseAddActivationFusePass::Apply(
 REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass,
                   paddle::lite::mir::ElementwiseAddActivationFusePass)
     .BindTargets({TARGET(kAny)})
+    .ExcludeTargets({TARGET(kXPU)})
     .BindKernel("fusion_elementwise_add_activation");
diff --git a/lite/core/mir/fusion/fc_fuse_pass.cc b/lite/core/mir/fusion/fc_fuse_pass.cc
index ed10f06f5651f4000485279d682689101d80aa5a..7fc449219251bbd7e639e8092099f43fe8eca626 100644
--- a/lite/core/mir/fusion/fc_fuse_pass.cc
+++ b/lite/core/mir/fusion/fc_fuse_pass.cc
@@ -33,4 +33,5 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
     .BindTargets({TARGET(kAny)})
+    .ExcludeTargets({TARGET(kXPU)})
    .BindKernel("fc");
diff --git a/lite/kernels/npu/bridges/CMakeLists.txt b/lite/kernels/npu/bridges/CMakeLists.txt
index 032de819743f4aba02e442dd71c26b950d1435b6..1f4a98c048b48efd011323eac27f71865846dc87 100644
--- a/lite/kernels/npu/bridges/CMakeLists.txt
+++ b/lite/kernels/npu/bridges/CMakeLists.txt
@@ -19,6 +19,7 @@ lite_cc_library(npu_bridge_split_op SRCS split_op.cc DEPS ${npu_bridge_deps})
 lite_cc_library(npu_bridge_concat_op SRCS concat_op.cc DEPS ${npu_bridge_deps})
 lite_cc_library(npu_bridge_shuffle_channel_op SRCS shuffle_channel_op.cc DEPS ${npu_bridge_deps})
 lite_cc_library(npu_bridge_pad2d_op SRCS pad2d_op.cc DEPS ${npu_bridge_deps})
+lite_cc_library(npu_bridge_square_op SRCS square_op.cc DEPS ${npu_bridge_deps})
 
 set(npu_bridges
     npu_bridge_registry
@@ -39,6 +40,7 @@ set(npu_bridges
     npu_bridge_concat_op
     npu_bridge_shuffle_channel_op
     npu_bridge_pad2d_op
+    npu_bridge_square_op
     CACHE INTERNAL "npu_bridges")
 
 set(npu_bridge_test_deps ${npu_bridges} ${npu_kernels} ${ops})
@@ -60,5 +62,6 @@ lite_cc_test(test_npu_bridge_split_op SRCS split_op_test.cc test_helper.cc DEPS
 lite_cc_test(test_npu_bridge_concat_op SRCS concat_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
 lite_cc_test(test_npu_bridge_shuffle_channel_op SRCS shuffle_channel_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
 lite_cc_test(test_npu_bridge_pad2d_op SRCS pad2d_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
+lite_cc_test(test_npu_bridge_square_op SRCS square_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
 
 message(STATUS "+++++ npu_bridges: ${npu_bridges}")
" << paddings[0] << ", " - << paddings[1] << ", " << paddings[2] << ", " << paddings[3]; + if (paddings.size() == 2L) { + for (size_t i = 0; i < strides.size(); ++i) { + int copy_pad = *(paddings.begin() + 2 * i); + paddings.insert(paddings.begin() + 2 * i + 1, copy_pad); + } + } + CHECK_EQ(paddings.size(), 4L) + << "Paddings size should be the same or twice as the input size."; + + std::string padding_algorithm(""); + if (op_info->HasAttr("padding_algorithm")) { + padding_algorithm = op_info->GetAttr("padding_algorithm"); } + operators::UpdatePaddingAndDilation(&paddings, + &dilations, + strides, + padding_algorithm, + input_dims, + filter_dims); // check depthwise mode, and decide whether use ConvolutionDepthwise Op bool use_depthwise_conv = @@ -141,7 +154,7 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, depthwise_conv_node->set_attr_pad_mode(5); // VALID depthwise_conv_node->set_attr_group(groups); depthwise_conv_node->set_attr_pad(ge::AttrValue::LIST_INT( - {paddings[0], paddings[0], paddings[2], paddings[2]})); + {paddings[0], paddings[1], paddings[2], paddings[3]})); depthwise_conv_node->set_attr_dilation( ge::AttrValue::LIST_INT({dilations[0], dilations[1]})); depthwise_conv_node->set_attr_stride( diff --git a/lite/kernels/npu/bridges/conv_transpose_op.cc b/lite/kernels/npu/bridges/conv_transpose_op.cc index 9e41829dc69b7b3247b1e487ebc9b6c59e1c6361..6eff4cb2d28d64098186dfb50a457a8828b8eb61 100644 --- a/lite/kernels/npu/bridges/conv_transpose_op.cc +++ b/lite/kernels/npu/bridges/conv_transpose_op.cc @@ -45,18 +45,21 @@ node_map_type ConvTransposeConverter( auto dilations = op_info->GetAttr>("dilations"); auto fuse_relu = op_info->GetAttr("fuse_relu"); CHECK_EQ(strides.size(), 2L); - CHECK_EQ(paddings.size(), 4L); CHECK_EQ(dilations.size(), 2L); + if (paddings.size() == 2L) { + for (size_t i = 0; i < 2L; ++i) { + int copy_pad = *(paddings.begin() + 2 * i); + paddings.insert(paddings.begin() + 2 * i + 1, copy_pad); + } + } + CHECK_EQ(paddings.size(), 4L) + << "Paddings size should be the same or twice as the input size."; + // create deconv node auto conv_transpose_node = std::make_shared(unique_op_type); - bool pad_equal = - ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); - if (!pad_equal) { - LOG(FATAL) << "This pad not support ! " << paddings[0] << ", " - << paddings[1] << ", " << paddings[2] << ", " << paddings[3]; - } + // create input sizes node to describe the dimensions of input tensor std::vector output_shape; output_shape.push_back(input_shape[0]); @@ -91,7 +94,7 @@ node_map_type ConvTransposeConverter( conv_transpose_node->set_attr_pad_mode(0); // NOTSET conv_transpose_node->set_attr_group(groups); conv_transpose_node->set_attr_pad(ge::AttrValue::LIST_INT( - {paddings[0], paddings[0], paddings[1], paddings[1]})); + {paddings[0], paddings[1], paddings[2], paddings[3]})); conv_transpose_node->set_attr_dilation( ge::AttrValue::LIST_INT({dilations[0], dilations[1]})); conv_transpose_node->set_attr_stride( diff --git a/lite/kernels/npu/bridges/square_op.cc b/lite/kernels/npu/bridges/square_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..2ca91adba0a8b24e6559599cb5952f8b47722ba3 --- /dev/null +++ b/lite/kernels/npu/bridges/square_op.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
diff --git a/lite/kernels/npu/bridges/square_op.cc b/lite/kernels/npu/bridges/square_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2ca91adba0a8b24e6559599cb5952f8b47722ba3
--- /dev/null
+++ b/lite/kernels/npu/bridges/square_op.cc
@@ -0,0 +1,55 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/backends/npu/builder.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace npu {
+namespace bridges {
+
+node_map_type SquareConverter(const std::shared_ptr<lite::OpLite> square_op,
+                              const node_map_type& inputs_map) {
+  auto scope = square_op->scope();
+  auto op_info = square_op->op_info();
+  auto op_type = op_info->Type();
+  auto unique_op_type = lite::npu::UniqueName(op_type);
+  LOG(INFO) << "[NPU] Converting " + op_type + "...";
+
+  std::shared_ptr<ge::op::Square> square_node =
+      std::make_shared<ge::op::Square>(unique_op_type);
+
+  auto x_var_name = op_info->Input("X").front();
+
+  CHECK(inputs_map.count(x_var_name));
+  square_node->set_input_x(*inputs_map.at(x_var_name));
+
+  lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
+  lite::npu::OpList::Global().add(square_node);
+
+  node_map_type outputs_map;
+  outputs_map[op_info->Output("Out").front()] = square_node;
+  return outputs_map;
+}
+
+}  // namespace bridges
+}  // namespace npu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_NPU_BRIDGE(square,
+                    paddle::lite::kernels::npu::bridges::SquareConverter);
diff --git a/lite/kernels/npu/bridges/square_op_test.cc b/lite/kernels/npu/bridges/square_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d715c11430096a0b6503fbe6047a40c3c29ba8f5
--- /dev/null
+++ b/lite/kernels/npu/bridges/square_op_test.cc
@@ -0,0 +1,92 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "lite/core/op_registry.h"
+#include "lite/kernels/npu/bridges/registry.h"
+#include "lite/kernels/npu/bridges/test_helper.h"
+#include "lite/operators/activation_ops.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace npu {
+namespace bridges {
+
+template <typename dtype>
+void square_ref(const std::shared_ptr<operators::ActivationOp> op) {
+  Scope* scope = op->scope();
+  const OpInfo* op_info = op->op_info();
+
+  auto x = scope->FindTensor("x");
+  auto out = scope->FindMutableTensor("out_ref");
+  out->Resize(x->dims());
+  auto x_data = x->data<dtype>();
+  auto out_data = out->mutable_data<dtype>();
+
+  for (size_t i = 0; i < x->numel(); i++) {
+    out_data[i] = x_data[i] * x_data[i];
+  }
+}
+
+void test_square(const std::vector<int64_t>& input_shape) {
+  // prepare input&output variables
+  Scope scope;
+  std::string x_var_name = "x";
+  std::string out_var_name = "out";
+  std::string out_ref_var_name = "out_ref";
+  auto* x = scope.NewTensor(x_var_name);
+  auto* out = scope.NewTensor(out_var_name);
+  auto* out_ref = scope.NewTensor(out_ref_var_name);
+  x->Resize(input_shape);
+
+  // initialize input&output data
+  FillTensor<float>(x);
+
+  // initialize op desc
+  cpp::OpDesc opdesc;
+  opdesc.SetType("square");
+  opdesc.SetInput("X", {x_var_name});
+  opdesc.SetOutput("Out", {out_var_name});
+
+  // create and convert op to NPU model, then run it on NPU
+  auto op = CreateOp<operators::ActivationOp>(opdesc, &scope);
+  LauchOp(op, {x_var_name}, {out_var_name});
+
+  // execute reference implementation and save to output tensor
+  square_ref<float>(op);
+
+  // compare results
+  auto* out_data = out->mutable_data<float>();
+  auto* out_ref_data = out_ref->mutable_data<float>();
+  for (int i = 0; i < out->dims().production(); i++) {
+    EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
+  }
+}
+
+TEST(NPUBridges, square) {
+  test_square({2});
+  test_square({2, 3});
+  test_square({1, 2, 3, 4});
+  test_square({5, 6, 7, 8});
+}
+
+}  // namespace bridges
+}  // namespace npu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_OP(square);
+USE_NPU_BRIDGE(square);
diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt
index b2b3bb79a4836dc07196899466697da4887fe5a6..a1d36151153741e7c4413e9ce7a1726729096bd3 100644
--- a/lite/kernels/x86/CMakeLists.txt
+++ b/lite/kernels/x86/CMakeLists.txt
@@ -29,6 +29,7 @@ add_kernel(sequence_expand_as_compute_x86 X86 basic SRCS sequence_expand_as_comp
 
 # lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86)
 # lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
+add_kernel(gather_compute_x86 X86 basic SRCS gather_compute.cc DEPS ${lite_kernel_deps})
 # lite_cc_test(test_scale_compute_x86 SRCS scale_compute_test.cc DEPS scale_compute_x86)
 # lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
 # lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86)
@@ -65,6 +66,7 @@ add_kernel(matmul_compute_x86 X86 basic SRCS matmul_compute.cc DEPS ${lite_kerne
 lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
 lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
+lite_cc_test(test_gather_compute_x86 SRCS gather_compute_test.cc DEPS gather_compute_x86)
 lite_cc_test(test_slice_compute_x86 SRCS slice_compute_test.cc DEPS slice_compute_x86)
 lite_cc_test(test_squeeze_compute_x86 SRCS squeeze_compute_test.cc DEPS squeeze_compute_x86)
 lite_cc_test(test_fill_constant_batch_size_like_compute_x86 SRCS fill_constant_batch_size_like_compute_test.cc DEPS fill_constant_batch_size_like_compute_x86)
diff --git a/lite/kernels/x86/gather_compute.cc b/lite/kernels/x86/gather_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..836f336271ef53c338cca89855b48c94c778cc54
--- /dev/null
+++ b/lite/kernels/x86/gather_compute.cc
@@ -0,0 +1,32 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/gather_compute.h"
+
+typedef paddle::lite::kernels::x86::GatherCompute<float, int32_t> GatherInt32;
+typedef paddle::lite::kernels::x86::GatherCompute<float, int64_t> GatherInt64;
+
+REGISTER_LITE_KERNEL(gather, kX86, kFloat, kNCHW, GatherInt32, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Index",
+               {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(gather, kX86, kFloat, kNCHW, GatherInt64, int64_in)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Index",
+               {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
diff --git a/lite/kernels/x86/gather_compute.h b/lite/kernels/x86/gather_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..6ee270647f8fb7d7ec540047cd4d546a7eb89ce8
--- /dev/null
+++ b/lite/kernels/x86/gather_compute.h
@@ -0,0 +1,99 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string.h>
+#include "lite/api/paddle_place.h"
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/types.h"
+#include "lite/fluid/data_type.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+/**
+ * A thin wrapper for gathering on cpu tensor
+ * Return a new tensor from source tensor, gathered according to index
+ * input[src]: type-T source Tensor
+ * input[index]: type-IndexT index Tensor (1-D)
+ * return: output tensor
+ */
+template <typename T, typename IndexT = int32_t>
+void CPUGather(const lite::Tensor* src,
+               const lite::Tensor* index,
+               lite::Tensor* output) {
+  // check index of shape 1-D
+  if (index->dims().size() == 2) {
+    CHECK(index->dims()[1] == 1) << "Index(Input)'s dimension[1] should be 1 "
+                                    "when Index(Input)'s dimension's size "
+                                    "equals 2 in Gather(Op).";
+  } else {
+    CHECK(index->dims().size() == 1)
+        << "Index(Input)'s dimension's size() should be 1 or 2 in Gather(Op).";
+  }
+  int64_t index_size = index->dims()[0];
+
+  auto src_dims = src->dims();
+
+  const T* p_src = src->data<T>();
+  const IndexT* p_index = index->data<IndexT>();
+  T* p_output = output->mutable_data<T>();
+
+  // slice size
+  int slice_size = 1;
+  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
+
+  const size_t slice_bytes = slice_size * sizeof(T);
+  for (int64_t i = 0; i < index_size; ++i) {
+    IndexT index_ = p_index[i];
+    memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
+  }
+}
+
+template <typename T, typename IndexT>
+class GatherCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::GatherParam;
+
+  void Run() override {
+    auto& param = *param_.get_mutable<operators::GatherParam>();
+
+    auto x = param.X;
+    auto index = param.Index;
+    auto out = param.Out;
+
+    out->mutable_data<T>();
+    if (x->dims().production() == 0) return;
+    /*
+     * Since lite::Tensor carries no element type of its own, the Index
+     * tensor's values (which must be int32_t or int64_t) are read according
+     * to the IndexT template argument; this is assumed to cause no precision
+     * difference during inference for now. Alternatively, fixing the
+     * tensor's type at registration time may cause a redefinition error.
+     */
+    CPUGather<T, IndexT>(x, index, out);
+  }
+
+  virtual ~GatherCompute() = default;
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
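The core of CPUGather is a row-wise memcpy: output row i is source row index[i], where a "row" (slice) is the product of all trailing dimensions. A rough self-contained sketch over plain arrays (names are mine, assuming the same semantics as the kernel above):

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Copy index_size slices of slice_size elements from src into out,
// picking slice index[i] for output slice i.
template <typename T, typename IndexT>
void GatherRows(const T* src, const IndexT* index, int64_t index_size,
                int slice_size, T* out) {
  const size_t slice_bytes = slice_size * sizeof(T);
  for (int64_t i = 0; i < index_size; ++i) {
    std::memcpy(out + i * slice_size, src + index[i] * slice_size, slice_bytes);
  }
}

int main() {
  std::vector<float> x(10 * 20);
  for (size_t i = 0; i < x.size(); ++i) x[i] = static_cast<float>(i);
  std::vector<int32_t> idx{1, 3, 5};
  std::vector<float> out(3 * 20);
  GatherRows(x.data(), idx.data(), 3, 20, out.data());
  // out now holds x's rows 1, 3, 5: out[0] == 20.f, out[20] == 60.f,
  // out[40] == 100.f -- matching the expectations in the 2-D test below.
  return 0;
}
```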
+ +#include "lite/kernels/x86/gather_compute.h" +#include +#include +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(gather_x86, retrive_op) { + auto gather = + KernelRegistry::Global().Create( + "gather"); + ASSERT_FALSE(gather.empty()); + int cnt = 0; + for (auto item = gather.begin(); item != gather.end(); ++item) { + cnt++; + ASSERT_TRUE(*item); + } + ASSERT_EQ(cnt, 2); +} + +TEST(gather_x86, int32_init) { + GatherCompute gather; + ASSERT_EQ(gather.precision(), PRECISION(kFloat)); + ASSERT_EQ(gather.target(), TARGET(kX86)); +} + +TEST(gather_x86, int64_init) { + GatherCompute gather; + ASSERT_EQ(gather.precision(), PRECISION(kFloat)); + ASSERT_EQ(gather.target(), TARGET(kX86)); +} + +template +void test_case_1dims() { + lite::Tensor x, index, out; + std::vector x_shape{10}; + x.Resize(lite::DDim(x_shape)); + std::vector index_shape{3}; + index.Resize(lite::DDim(index_shape)); + std::vector out_shape{3}; + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto index_data = index.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); ++i) { + x_data[i] = static_cast(i); + } + std::vector index_value{1, 3, 5}; + for (int i = 0; i < index.dims().production(); ++i) { + index_data[i] = static_cast(index_value[i]); + } + + GatherCompute gather; + operators::GatherParam param; + + param.X = &x; + param.Index = &index; + param.Out = &out; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + gather.SetContext(std::move(ctx)); + gather.SetParam(param); + gather.Run(); + + std::vector ref_data{1, 3, 5}; + for (int i = 0; i < out.dims().production(); i++) { + EXPECT_NEAR(out_data[i], ref_data[i], 1e-5); + } +} + +template +void test_case_2dims() { + lite::Tensor x, index, out; + std::vector x_shape{10, 20}; + x.Resize(lite::DDim(x_shape)); + std::vector index_shape{3}; + index.Resize(lite::DDim(index_shape)); + std::vector out_shape{3, 20}; + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto index_data = index.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); ++i) { + x_data[i] = static_cast(i); + } + std::vector index_value{1, 3, 5}; + for (int i = 0; i < index.dims().production(); ++i) { + index_data[i] = static_cast(index_value[i]); + } + + GatherCompute gather; + operators::GatherParam param; + + param.X = &x; + param.Index = &index; + param.Out = &out; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + gather.SetContext(std::move(ctx)); + gather.SetParam(param); + gather.Run(); + + std::vector ref_data(60); + for (int i = 0; i < 20; ++i) { + ref_data[i] = static_cast(20 + i); + } + for (int i = 20; i < 40; ++i) { + ref_data[i] = static_cast(40 + i); + } + for (int i = 40; i < 60; ++i) { + ref_data[i] = static_cast(60 + i); + } + for (int i = 0; i < out.dims().production(); i++) { + EXPECT_NEAR(out_data[i], ref_data[i], 1e-5); + } +} + +TEST(gather_x86, run_test_1dims) { + test_case_1dims(); + test_case_1dims(); +} + +TEST(gather_x86, run_test_2dims) { + test_case_2dims(); + test_case_2dims(); +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(gather, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(gather, kX86, kFloat, kNCHW, int64_in); diff --git a/lite/kernels/xpu/bridges/conv_op.cc b/lite/kernels/xpu/bridges/conv_op.cc index 
9acb0e4e3d186ec4cda826ad22323ebbd1b38779..d6fc806ad4541a735ea4ef6eff292076836ac5e7 100644
--- a/lite/kernels/xpu/bridges/conv_op.cc
+++ b/lite/kernels/xpu/bridges/conv_op.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "lite/operators/conv_op.h"
 #include "lite/backends/xpu/builder.h"
 #include "lite/kernels/xpu/bridges/registry.h"
 
@@ -47,8 +48,28 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> op,
   auto dilations = op_info->GetAttr<std::vector<int>>("dilations");
   auto fuse_relu = op_info->GetAttr<bool>("fuse_relu");
   CHECK_EQ(strides.size(), 2L);
-  CHECK_EQ(paddings.size(), 4L);
   CHECK_EQ(dilations.size(), 2L);
+
+  if (paddings.size() == 2L) {
+    for (size_t i = 0; i < strides.size(); ++i) {
+      int copy_pad = *(paddings.begin() + 2 * i);
+      paddings.insert(paddings.begin() + 2 * i + 1, copy_pad);
+    }
+  }
+  CHECK_EQ(paddings.size(), 4L)
+      << "Paddings size should be the same or twice as the input size.";
+
+  std::string padding_algorithm("");
+  if (op_info->HasAttr("padding_algorithm")) {
+    padding_algorithm = op_info->GetAttr<std::string>("padding_algorithm");
+  }
+  operators::UpdatePaddingAndDilation(&paddings,
+                                      &dilations,
+                                      strides,
+                                      padding_algorithm,
+                                      input_dims,
+                                      filter_dims);
+
   std::vector<int64_t> output_shape({bs, oc});
   for (size_t i = 0; i < 2; i++) {
     const int dkernel = dilations[i] * (filter_dims[2 + i] - 1) + 1;
@@ -59,12 +80,6 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> op,
   }
   DDim output_dims(output_shape);
 
-  bool pads_equal =
-      (paddings[0] == paddings[1]) && (paddings[2] == paddings[3]);
-  if (!pads_equal) {
-    LOG(FATAL) << "Padding requies pad_top==pad_bottom and pad_lef==pad_right.";
-  }
-
   // check context
   CHECK(graph_ctx != nullptr);
   CHECK(graph_ctx->builder != nullptr);
diff --git a/lite/operators/conv_op.cc b/lite/operators/conv_op.cc
index d9c0ecb4fd8457782ac90850b8b6a002c7dfcffe..6dab55ff3b6c55e7763484d78c6c36bf85017128 100644
--- a/lite/operators/conv_op.cc
+++ b/lite/operators/conv_op.cc
@@ -52,34 +52,6 @@ inline int ConvOutputSize(int input_size,
   return output_size;
 }
 
-inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
-                                     std::vector<int>* dilations,
-                                     const std::vector<int>& strides,
-                                     const std::string padding_algorithm,
-                                     const lite::DDim data_dims,
-                                     const lite::DDim& ksize) {
-  // when padding_desc is "VALID" or "SAME"
-  if (padding_algorithm == "SAME") {
-    for (size_t i = 0; i < strides.size(); ++i) {
-      int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
-      int pad_sum = std::max(
-          (out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
-          (int64_t)0);
-      int pad_0 = pad_sum / 2;
-      int pad_1 = pad_sum - pad_0;
-      // pad
-      *(paddings->begin() + i * 2) = pad_0;
-      *(paddings->begin() + i * 2 + 1) = pad_1;
-      // dilation
-      *(dilations->begin() + i) = 1;
-    }
-  } else if (padding_algorithm == "VALID") {
-    for (auto& it : *paddings) {
-      it = 0;
-    }
-  }
-}
-
 bool ConvOpLite::InferShape() const {
   const auto in_dims = param_.x->dims();
   const auto filter_dims = param_.filter->dims();
diff --git a/lite/operators/conv_op.h b/lite/operators/conv_op.h
index 24848803fb7ea2139f87aa5b5f2119592dc00084..3ab34bc1d0bd631b0641cebd3db29cfff9316bb0 100644
--- a/lite/operators/conv_op.h
+++ b/lite/operators/conv_op.h
@@ -137,6 +137,34 @@ class ConvOpLite : public OpLite {
   std::string padding_algorithm_{""};
 };
 
+inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
+                                     std::vector<int>* dilations,
+                                     const std::vector<int>& strides,
+                                     const std::string padding_algorithm,
+                                     const lite::DDim data_dims,
+                                     const lite::DDim& ksize) {
+  // when padding_algorithm is "VALID" or "SAME"
+  if (padding_algorithm == "SAME") {
+    for (size_t i = 0; i < strides.size(); ++i) {
+      int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
+      int pad_sum = std::max(
+          (out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
+          (int64_t)0);
+      int pad_0 = pad_sum / 2;
+      int pad_1 = pad_sum - pad_0;
+      // pad
+      *(paddings->begin() + i * 2) = pad_0;
+      *(paddings->begin() + i * 2 + 1) = pad_1;
+      // dilation
+      *(dilations->begin() + i) = 1;
+    }
+  } else if (padding_algorithm == "VALID") {
+    for (auto& it : *paddings) {
+      it = 0;
+    }
+  }
+}
+
 }  // namespace operators
 }  // namespace lite
 }  // namespace paddle
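For reference, the SAME branch of UpdatePaddingAndDilation computes a ceil-divided output size and splits the total padding needed to reach it across the two sides of each spatial dimension. A standalone mimic for a single dimension (hypothetical helper name, same arithmetic as the function moved above):

```cpp
#include <algorithm>
#include <cassert>

// SAME padding for one spatial dimension: output size is ceil(data/stride),
// and the padding required to reach it is split across the two sides.
void SamePad(int data, int stride, int k, int* pad0, int* pad1) {
  int out_size = (data + stride - 1) / stride;  // ceil division
  int pad_sum = std::max((out_size - 1) * stride + k - data, 0);
  *pad0 = pad_sum / 2;
  *pad1 = pad_sum - *pad0;
}

int main() {
  // Worked example: H = 7, stride = 2, kernel = 3.
  // out_size = ceil(7 / 2) = 4, pad_sum = max(3 * 2 + 3 - 7, 0) = 2,
  // so one row of padding on each side (and dilation is reset to 1 above).
  int p0 = 0, p1 = 0;
  SamePad(7, 2, 3, &p0, &p1);
  assert(p0 == 1 && p1 == 1);
  return 0;
}
```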