Commit 7100f6a8 authored by qnqinan

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle-Lite into develop

@@ -47,4 +47,5 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 REGISTER_MIR_PASS(lite_conv_activation_fuse_pass,
                   paddle::lite::mir::ConvActivationFusePass)
     .BindTargets({TARGET(kAny)})
+    .ExcludeTargets({TARGET(kXPU)})
     .BindKernel("conv2d");
@@ -45,4 +45,4 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass)
     .BindTargets({TARGET(kAny)})
-    .ExcludeTargets({TARGET(kX86)});
+    .ExcludeTargets({TARGET(kX86), TARGET(kXPU)});
@@ -46,4 +46,5 @@ void ConvElementwiseFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 REGISTER_MIR_PASS(lite_conv_elementwise_fuse_pass,
                   paddle::lite::mir::ConvElementwiseFusePass)
-    .BindTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)})
+    .ExcludeTargets({TARGET(kXPU)});
@@ -35,4 +35,5 @@ void ElementwiseAddActivationFusePass::Apply(
 REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass,
                   paddle::lite::mir::ElementwiseAddActivationFusePass)
     .BindTargets({TARGET(kAny)})
+    .ExcludeTargets({TARGET(kXPU)})
     .BindKernel("fusion_elementwise_add_activation");
@@ -33,4 +33,5 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
     .BindTargets({TARGET(kAny)})
+    .ExcludeTargets({TARGET(kXPU)})
     .BindKernel("fc");
@@ -19,6 +19,7 @@ lite_cc_library(npu_bridge_split_op SRCS split_op.cc DEPS ${npu_bridge_deps})
 lite_cc_library(npu_bridge_concat_op SRCS concat_op.cc DEPS ${npu_bridge_deps})
 lite_cc_library(npu_bridge_shuffle_channel_op SRCS shuffle_channel_op.cc DEPS ${npu_bridge_deps})
 lite_cc_library(npu_bridge_pad2d_op SRCS pad2d_op.cc DEPS ${npu_bridge_deps})
+lite_cc_library(npu_bridge_square_op SRCS square_op.cc DEPS ${npu_bridge_deps})
 set(npu_bridges
     npu_bridge_registry
@@ -39,6 +40,7 @@ set(npu_bridges
     npu_bridge_concat_op
     npu_bridge_shuffle_channel_op
     npu_bridge_pad2d_op
+    npu_bridge_square_op
     CACHE INTERNAL "npu_bridges")
 set(npu_bridge_test_deps ${npu_bridges} ${npu_kernels} ${ops})
@@ -60,5 +62,6 @@ lite_cc_test(test_npu_bridge_split_op SRCS split_op_test.cc test_helper.cc DEPS
 lite_cc_test(test_npu_bridge_concat_op SRCS concat_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
 lite_cc_test(test_npu_bridge_shuffle_channel_op SRCS shuffle_channel_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
 lite_cc_test(test_npu_bridge_pad2d_op SRCS pad2d_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
+lite_cc_test(test_npu_bridge_square_op SRCS square_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
 message(STATUS "+++++ npu_bridges: ${npu_bridges}")
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "lite/operators/conv_op.h"
 #include "lite/backends/npu/builder.h"
 #include "lite/kernels/npu/bridges/registry.h"
@@ -53,15 +54,27 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
   auto dilations = op_info->GetAttr<std::vector<int>>("dilations");
   auto fuse_relu = op_info->GetAttr<bool>("fuse_relu");
   CHECK_EQ(strides.size(), 2L);
-  CHECK_EQ(paddings.size(), 4L);
   CHECK_EQ(dilations.size(), 2L);
-  bool pad_equal =
-      ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
-  if (!pad_equal) {
-    LOG(FATAL) << "This pad not support ! " << paddings[0] << ", "
-               << paddings[1] << ", " << paddings[2] << ", " << paddings[3];
-  }
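+  // A 2-element paddings attribute [pad_h, pad_w] is expanded in place into
+  // the 4-element [pad_top, pad_bottom, pad_left, pad_right] form, so that
+  // asymmetric padding can be handled uniformly below.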
+  if (paddings.size() == 2L) {
+    for (size_t i = 0; i < strides.size(); ++i) {
+      int copy_pad = *(paddings.begin() + 2 * i);
+      paddings.insert(paddings.begin() + 2 * i + 1, copy_pad);
+    }
+  }
+  CHECK_EQ(paddings.size(), 4L)
+      << "Paddings size should be the same or twice as the input size.";
+  std::string padding_algorithm("");
+  if (op_info->HasAttr("padding_algorithm")) {
+    padding_algorithm = op_info->GetAttr<std::string>("padding_algorithm");
+  }
+  operators::UpdatePaddingAndDilation(&paddings,
+                                      &dilations,
+                                      strides,
+                                      padding_algorithm,
+                                      input_dims,
+                                      filter_dims);
   // check depthwise mode, and decide whether use ConvolutionDepthwise Op
   bool use_depthwise_conv =
@@ -141,7 +154,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
     depthwise_conv_node->set_attr_pad_mode(5);  // VALID
     depthwise_conv_node->set_attr_group(groups);
     depthwise_conv_node->set_attr_pad(ge::AttrValue::LIST_INT(
-        {paddings[0], paddings[0], paddings[2], paddings[2]}));
+        {paddings[0], paddings[1], paddings[2], paddings[3]}));
     depthwise_conv_node->set_attr_dilation(
         ge::AttrValue::LIST_INT({dilations[0], dilations[1]}));
     depthwise_conv_node->set_attr_stride(
......
@@ -45,18 +45,21 @@ node_map_type ConvTransposeConverter(
   auto dilations = op_info->GetAttr<std::vector<int>>("dilations");
   auto fuse_relu = op_info->GetAttr<bool>("fuse_relu");
   CHECK_EQ(strides.size(), 2L);
-  CHECK_EQ(paddings.size(), 4L);
   CHECK_EQ(dilations.size(), 2L);
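+  // Same [pad_h, pad_w] -> [top, bottom, left, right] expansion as in
+  // ConvConverter.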
+  if (paddings.size() == 2L) {
+    for (size_t i = 0; i < 2L; ++i) {
+      int copy_pad = *(paddings.begin() + 2 * i);
+      paddings.insert(paddings.begin() + 2 * i + 1, copy_pad);
+    }
+  }
+  CHECK_EQ(paddings.size(), 4L)
+      << "Paddings size should be the same or twice as the input size.";
   // create deconv node
   auto conv_transpose_node =
       std::make_shared<ge::op::Deconvolution>(unique_op_type);
-  bool pad_equal =
-      ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
-  if (!pad_equal) {
-    LOG(FATAL) << "This pad not support ! " << paddings[0] << ", "
-               << paddings[1] << ", " << paddings[2] << ", " << paddings[3];
-  }
   // create input sizes node to describe the dimensions of input tensor
   std::vector<int32_t> output_shape;
   output_shape.push_back(input_shape[0]);
@@ -91,7 +94,7 @@ node_map_type ConvTransposeConverter(
   conv_transpose_node->set_attr_pad_mode(0);  // NOTSET
   conv_transpose_node->set_attr_group(groups);
   conv_transpose_node->set_attr_pad(ge::AttrValue::LIST_INT(
-      {paddings[0], paddings[0], paddings[1], paddings[1]}));
+      {paddings[0], paddings[1], paddings[2], paddings[3]}));
   conv_transpose_node->set_attr_dilation(
       ge::AttrValue::LIST_INT({dilations[0], dilations[1]}));
   conv_transpose_node->set_attr_stride(
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
node_map_type SquareConverter(const std::shared_ptr<lite::OpLite> square_op,
const node_map_type& inputs_map) {
auto scope = square_op->scope();
auto op_info = square_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "[NPU] Converting " + op_type + "...";
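// Map paddle's square op to a HiAI ge::op::Square node and wire its input
// to the previously converted graph node.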
std::shared_ptr<ge::op::Square> square_node =
std::make_shared<ge::op::Square>(unique_op_type);
auto x_var_name = op_info->Input("X").front();
CHECK(inputs_map.count(x_var_name));
square_node->set_input_x(*inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(square_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = square_node;
return outputs_map;
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(square,
paddle::lite::kernels::npu::bridges::SquareConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/test_helper.h"
#include "lite/operators/activation_ops.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
template <typename dtype>
void square_ref(const std::shared_ptr<operators::ActivationOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindTensor("x");
auto out = scope->FindMutableTensor("out_ref");
out->Resize(x->dims());
auto x_data = x->data<dtype>();
auto out_data = out->mutable_data<dtype>();
for (size_t i = 0; i < x->numel(); i++) {
out_data[i] = x_data[i] * x_data[i];
}
}
void test_square(const std::vector<int64_t>& input_shape) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.NewTensor(x_var_name);
auto* out = scope.NewTensor(out_var_name);
auto* out_ref = scope.NewTensor(out_ref_var_name);
x->Resize(input_shape);
// initialize input&output data
FillTensor<float>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("square");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
// create and convert op to NPU model, then run it on NPU
auto op = CreateOp<operators::ActivationOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
// execute reference implementation and save to output tensor
square_ref<float>(op);
// compare results
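// (tolerance is loose at 1e-2, presumably because the NPU computes in
// reduced precision)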
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
TEST(NPUBridges, square) {
test_square({2});
test_square({2, 3});
test_square({1, 2, 3, 4});
test_square({5, 6, 7, 8});
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(square);
USE_NPU_BRIDGE(square);
@@ -29,6 +29,7 @@ add_kernel(sequence_expand_as_compute_x86 X86 basic SRCS sequence_expand_as_comp
 # lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86)
 # lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
+add_kernel(gather_compute_x86 X86 basic SRCS gather_compute.cc DEPS ${lite_kernel_deps})
 # lite_cc_test(test_scale_compute_x86 SRCS scale_compute_test.cc DEPS scale_compute_x86)
 # lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
 # lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86)
@@ -65,6 +66,7 @@ add_kernel(matmul_compute_x86 X86 basic SRCS matmul_compute.cc DEPS ${lite_kerne
 lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
 lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
+lite_cc_test(test_gather_compute_x86 SRCS gather_compute_test.cc DEPS gather_compute_x86)
 lite_cc_test(test_slice_compute_x86 SRCS slice_compute_test.cc DEPS slice_compute_x86)
 lite_cc_test(test_squeeze_compute_x86 SRCS squeeze_compute_test.cc DEPS squeeze_compute_x86)
 lite_cc_test(test_fill_constant_batch_size_like_compute_x86 SRCS fill_constant_batch_size_like_compute_test.cc DEPS fill_constant_batch_size_like_compute_x86)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/gather_compute.h"
typedef paddle::lite::kernels::x86::GatherCompute<float, int32_t> GatherInt32;
typedef paddle::lite::kernels::x86::GatherCompute<float, int64_t> GatherInt64;
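// The same float kernel is registered twice, differing only in the declared
// precision of the Index input (kInt32 vs kInt64); "def" and "int64_in" are
// the registration names used when resolving kernels.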
REGISTER_LITE_KERNEL(gather, kX86, kFloat, kNCHW, GatherInt32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Index",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
REGISTER_LITE_KERNEL(gather, kX86, kFloat, kNCHW, GatherInt64, int64_in)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Index",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
#include "lite/fluid/data_type.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
/**
 * A thin wrapper for gathering on a CPU tensor.
 * Returns a new tensor built from the source tensor, gathered according to
 * index:
 *   input[src]: type-T source Tensor
 *   input[index]: type-IndexT index Tensor (1-D)
 *   return: output tensor
 */
template <typename T, typename IndexT = int>
void CPUGather(const lite::Tensor* src,
const lite::Tensor* index,
lite::Tensor* output) {
// check index of shape 1-D
if (index->dims().size() == 2) {
CHECK(index->dims()[1] == 1) << "Index(Input)'s dimension[1] should be 1 "
"when Index(input)'s dimension's size "
"equal to 2 in Gather(Op).";
} else {
CHECK(index->dims().size() == 1)
<< "Index(Input)'s dimension's size() should be 1 or 2 in Gather(Op).";
}
int64_t index_size = index->dims()[0];
auto src_dims = src->dims();
const T* p_src = src->data<T>();
const IndexT* p_index = index->data<IndexT>();
T* p_output = output->mutable_data<T>();
// slice size
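// (elements per gathered row: the product of all dims except the first,
// since this gather always selects along axis 0)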
int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
const size_t slice_bytes = slice_size * sizeof(T);
for (int64_t i = 0; i < index_size; ++i) {
int index_ = p_index[i];
memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
}
}
template <typename T, typename IndexT>
class GatherCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::GatherParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto x = param.X;
auto index = param.Index;
auto out = param.Out;
out->mutable_data<T>();
if (x->dims().production() == 0) return;
/*
 * lite::Tensor does not carry a compile-time element type, so the element
 * type of the Index tensor (int32_t or int64_t) is fixed by the IndexT
 * template parameter chosen when the kernel is registered; this is expected
 * to cause no precision difference during inference. Declaring the tensor
 * type at registration instead could trigger a kernel redefinition error.
 */
CPUGather<T, IndexT>(x, index, out);
}
virtual ~GatherCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/gather_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(gather_x86, retrive_op) {
auto gather =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"gather");
ASSERT_FALSE(gather.empty());
int cnt = 0;
for (auto item = gather.begin(); item != gather.end(); ++item) {
cnt++;
ASSERT_TRUE(*item);
}
ASSERT_EQ(cnt, 2);
}
TEST(gather_x86, int32_init) {
GatherCompute<float, int32_t> gather;
ASSERT_EQ(gather.precision(), PRECISION(kFloat));
ASSERT_EQ(gather.target(), TARGET(kX86));
}
TEST(gather_x86, int64_init) {
GatherCompute<float, int64_t> gather;
ASSERT_EQ(gather.precision(), PRECISION(kFloat));
ASSERT_EQ(gather.target(), TARGET(kX86));
}
template <typename T>
void test_case_1dims() {
lite::Tensor x, index, out;
std::vector<int64_t> x_shape{10};
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> index_shape{3};
index.Resize(lite::DDim(index_shape));
std::vector<int64_t> out_shape{3};
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto index_data = index.mutable_data<T>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
std::vector<float> index_value{1, 3, 5};
for (int i = 0; i < index.dims().production(); ++i) {
index_data[i] = static_cast<T>(index_value[i]);
}
GatherCompute<float, T> gather;
operators::GatherParam param;
param.X = &x;
param.Index = &index;
param.Out = &out;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
gather.SetContext(std::move(ctx));
gather.SetParam(param);
gather.Run();
std::vector<float> ref_data{1, 3, 5};
for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_data[i], ref_data[i], 1e-5);
}
}
template <typename T>
void test_case_2dims() {
lite::Tensor x, index, out;
std::vector<int64_t> x_shape{10, 20};
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> index_shape{3};
index.Resize(lite::DDim(index_shape));
std::vector<int64_t> out_shape{3, 20};
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto index_data = index.mutable_data<T>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
std::vector<float> index_value{1, 3, 5};
for (int i = 0; i < index.dims().production(); ++i) {
index_data[i] = static_cast<T>(index_value[i]);
}
GatherCompute<float, T> gather;
operators::GatherParam param;
param.X = &x;
param.Index = &index;
param.Out = &out;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
gather.SetContext(std::move(ctx));
gather.SetParam(param);
gather.Run();
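// x holds 0..199 row-major over shape [10, 20], so gathered rows {1, 3, 5}
// should contain values 20..39, 60..79 and 100..119 respectively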
std::vector<float> ref_data(60);
for (int i = 0; i < 20; ++i) {
ref_data[i] = static_cast<float>(20 + i);
}
for (int i = 20; i < 40; ++i) {
ref_data[i] = static_cast<float>(40 + i);
}
for (int i = 40; i < 60; ++i) {
ref_data[i] = static_cast<float>(60 + i);
}
for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_data[i], ref_data[i], 1e-5);
}
}
TEST(gather_x86, run_test_1dims) {
test_case_1dims<int32_t>();
test_case_1dims<int64_t>();
}
TEST(gather_x86, run_test_2dims) {
test_case_2dims<int32_t>();
test_case_2dims<int64_t>();
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(gather, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(gather, kX86, kFloat, kNCHW, int64_in);
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "lite/operators/conv_op.h"
 #include "lite/backends/xpu/builder.h"
 #include "lite/kernels/xpu/bridges/registry.h"
@@ -47,8 +48,28 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> op,
   auto dilations = op_info->GetAttr<std::vector<int>>("dilations");
   auto fuse_relu = op_info->GetAttr<bool>("fuse_relu");
   CHECK_EQ(strides.size(), 2L);
-  CHECK_EQ(paddings.size(), 4L);
   CHECK_EQ(dilations.size(), 2L);
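+  // Same paddings normalization and padding_algorithm handling as in the
+  // NPU bridge above.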
+  if (paddings.size() == 2L) {
+    for (size_t i = 0; i < strides.size(); ++i) {
+      int copy_pad = *(paddings.begin() + 2 * i);
+      paddings.insert(paddings.begin() + 2 * i + 1, copy_pad);
+    }
+  }
+  CHECK_EQ(paddings.size(), 4L)
+      << "Paddings size should be the same or twice as the input size.";
+  std::string padding_algorithm("");
+  if (op_info->HasAttr("padding_algorithm")) {
+    padding_algorithm = op_info->GetAttr<std::string>("padding_algorithm");
+  }
+  operators::UpdatePaddingAndDilation(&paddings,
+                                      &dilations,
+                                      strides,
+                                      padding_algorithm,
+                                      input_dims,
+                                      filter_dims);
   std::vector<int64_t> output_shape({bs, oc});
   for (size_t i = 0; i < 2; i++) {
     const int dkernel = dilations[i] * (filter_dims[2 + i] - 1) + 1;
@@ -59,12 +80,6 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> op,
   }
   DDim output_dims(output_shape);
-  bool pads_equal =
-      (paddings[0] == paddings[1]) && (paddings[2] == paddings[3]);
-  if (!pads_equal) {
-    LOG(FATAL) << "Padding requies pad_top==pad_bottom and pad_lef==pad_right.";
-  }
   // check context
   CHECK(graph_ctx != nullptr);
   CHECK(graph_ctx->builder != nullptr);
......
@@ -52,34 +52,6 @@ inline int ConvOutputSize(int input_size,
   return output_size;
 }
-inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
-                                     std::vector<int>* dilations,
-                                     const std::vector<int>& strides,
-                                     const std::string padding_algorithm,
-                                     const lite::DDim data_dims,
-                                     const lite::DDim& ksize) {
-  // when padding_desc is "VALID" or "SAME"
-  if (padding_algorithm == "SAME") {
-    for (size_t i = 0; i < strides.size(); ++i) {
-      int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
-      int pad_sum = std::max(
-          (out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
-          (int64_t)0);
-      int pad_0 = pad_sum / 2;
-      int pad_1 = pad_sum - pad_0;
-      // pad
-      *(paddings->begin() + i * 2) = pad_0;
-      *(paddings->begin() + i * 2 + 1) = pad_1;
-      // dilation
-      *(dilations->begin() + i) = 1;
-    }
-  } else if (padding_algorithm == "VALID") {
-    for (auto& it : *paddings) {
-      it = 0;
-    }
-  }
-}
 bool ConvOpLite::InferShape() const {
   const auto in_dims = param_.x->dims();
   const auto filter_dims = param_.filter->dims();
@@ -137,6 +137,34 @@ class ConvOpLite : public OpLite {
   std::string padding_algorithm_{""};
 };
+inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
+                                     std::vector<int>* dilations,
+                                     const std::vector<int>& strides,
+                                     const std::string padding_algorithm,
+                                     const lite::DDim data_dims,
+                                     const lite::DDim& ksize) {
+  // when padding_algorithm is "VALID" or "SAME"
+  if (padding_algorithm == "SAME") {
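+    // "SAME" keeps out = ceil(in / stride) elements per spatial axis:
+    //   pad_sum = max((out - 1) * stride + ksize - in, 0),
+    // split as evenly as possible with any extra pixel on the second side.
+    // e.g. in = 5, stride = 2, ksize = 3: out = 3, pad_sum = 2, so
+    // pad_0 = pad_1 = 1.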
+    for (size_t i = 0; i < strides.size(); ++i) {
+      int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
+      int pad_sum = std::max(
+          (out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
+          (int64_t)0);
+      int pad_0 = pad_sum / 2;
+      int pad_1 = pad_sum - pad_0;
+      // pad
+      *(paddings->begin() + i * 2) = pad_0;
+      *(paddings->begin() + i * 2 + 1) = pad_1;
+      // dilation
+      *(dilations->begin() + i) = 1;
+    }
+  } else if (padding_algorithm == "VALID") {
+    for (auto& it : *paddings) {
+      it = 0;
+    }
+  }
+}
 }  // namespace operators
 }  // namespace lite
 }  // namespace paddle