未验证 提交 9dcd9914 编写于 作者: W Wilber 提交者: GitHub

add var_conv_2d op, x86 kernel and unit test test=develop (#2422)

- add var_conv_2d op

- add var_conv_2d x86 kernel

- add var_conv_2d x86 test
上级 1f075a8b
...@@ -42,6 +42,7 @@ add_kernel(reduce_sum_compute_x86 X86 basic SRCS reduce_compute.cc DEPS ${lite_k ...@@ -42,6 +42,7 @@ add_kernel(reduce_sum_compute_x86 X86 basic SRCS reduce_compute.cc DEPS ${lite_k
add_kernel(lookup_table_compute_x86 X86 extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps}) add_kernel(lookup_table_compute_x86 X86 extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_reshape_compute_x86 X86 basic SRCS sequence_reshape_compute.cc DEPS ${lite_kernel_deps}) add_kernel(sequence_reshape_compute_x86 X86 basic SRCS sequence_reshape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_concat_compute_x86 X86 basic SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps}) add_kernel(sequence_concat_compute_x86 X86 basic SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps})
add_kernel(var_conv_2d_compute_x86 X86 basic SRCS var_conv_2d_compute.cc DEPS ${lite_kernel_deps} blas fluid_data_type)
if(NOT LITE_WITH_X86) if(NOT LITE_WITH_X86)
return() return()
...@@ -76,3 +77,4 @@ if(LITE_BUILD_EXTRA) ...@@ -76,3 +77,4 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_lookup_table_compute_x86 SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_x86) lite_cc_test(test_lookup_table_compute_x86 SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_x86)
endif() endif()
lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_x86) lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_x86)
lite_cc_test(test_var_conv_2d_compute_x86 SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_x86)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/var_conv_2d_compute.h"
REGISTER_LITE_KERNEL(var_conv_2d,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::VarConv2DCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("ROW", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("COLUMN", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Col", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/backends/x86/math/blas.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class VarConv2DCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::VarConv2DParam;
void Im2Col(const lite::Tensor& input, lite::Tensor* col) const {
auto& param = *param_.get_mutable<param_t>();
int input_channel = param.input_channel;
int kernel_h = param.kernel_h;
int kernel_w = param.kernel_w;
int stride_h = param.stride_h;
int stride_w = param.stride_w;
auto* in_row = param.ROW;
auto* in_col = param.COLUMN;
int batch = input.lod()[0].size() - 1;
const auto& bottom_offset = input.lod()[0];
// 2-D lod info.
const auto& offset_x = in_col->lod()[0];
const auto& offset_y = in_row->lod()[0];
// top offset is the whole size of each data sample
std::vector<uint64_t> top_offset;
int top_size = 0;
top_offset.push_back(top_size);
for (int b = 0; b < batch; ++b) {
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
int top_im_x = 0;
if (width == 0) {
top_im_x = 0;
} else {
top_im_x = (width - 1) / stride_w + 1;
}
int top_im_y = 0;
if (height == 0) {
top_im_y = 0;
} else {
top_im_y = (height - 1) / stride_h + 1;
}
int top_x = top_im_x * top_im_y;
int top_y = input_channel * kernel_h * kernel_w;
top_size += top_y * top_x;
top_offset.push_back(top_size);
}
// std::vector<int64_t> col_lod_vec;
// col_lod_vec.push_back(top_offset);
LoD col_lod;
col_lod.push_back(top_offset);
col->set_lod(col_lod);
std::vector<int64_t> col_dims_vec{top_size};
col_dims_vec.push_back(1);
col->Resize(col_dims_vec);
auto* top_data = col->mutable_data<T>();
const auto* bottom_data = input.data<T>();
int kernel_win_size = kernel_h * kernel_w;
int half_kernel_h = kernel_h / 2;
int half_kernel_w = kernel_w / 2;
for (int b = 0; b < batch; ++b) {
int t_offset = top_offset[b];
int b_offset = bottom_offset[b];
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
if (width == 0 || height == 0) {
continue;
}
int top_im_x = (width - 1) / stride_w + 1;
int top_im_y = (height - 1) / stride_h + 1;
int top_x = top_im_y * top_im_x;
for (int z = 0; z < input_channel; ++z) {
int row_offset = kernel_win_size * z;
int im_offset = z * width * height;
for (int y = 0; y < height; y += stride_h) {
for (int x = 0; x < width; x += stride_w) {
int col_offset = x / stride_w + y / stride_h * top_im_x;
for (int ky = 0; ky < kernel_h; ++ky) {
for (int kx = 0; kx < kernel_w; ++kx) {
int im_y = y + ky - half_kernel_h;
int im_x = x + kx - half_kernel_w;
if (im_x >= 0 && im_x < width && im_y >= 0 && im_y < height) {
top_data[t_offset +
(row_offset + ky * kernel_w + kx) * top_x +
col_offset] =
bottom_data[b_offset + im_offset + im_y * width + im_x];
} else {
top_data[t_offset +
(row_offset + ky * kernel_w + kx) * top_x +
col_offset] = 0;
}
}
}
}
}
}
}
}
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
auto* bottom = param.X;
auto* in_row = param.ROW;
auto* in_col = param.COLUMN;
auto* w = param.W;
auto* top = param.Out;
auto* col = param.Col;
int output_channel = param.output_channel;
int input_channel = param.input_channel;
int kernel_h = param.kernel_h;
int kernel_w = param.kernel_w;
int stride_h = param.stride_h;
int stride_w = param.stride_w;
Im2Col(*bottom, col);
int batch = bottom->lod()[0].size() - 1;
const auto& col_offset = col->lod()[0];
const auto& offset_x = in_col->lod()[0];
const auto& offset_y = in_row->lod()[0];
std::vector<size_t> top_offset;
int top_size = 0;
top_offset.push_back(top_size);
for (int b = 0; b < batch; ++b) {
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
int top_im_x = 0;
if (width == 0) {
top_im_x = 0;
} else {
top_im_x = (width - 1) / stride_w + 1;
}
int top_im_y = 0;
if (height == 0) {
top_im_y = 0;
} else {
top_im_y = (height - 1) / stride_h + 1;
}
int top_im_size = top_im_y * top_im_x;
top_size += output_channel * top_im_size;
top_offset.push_back(top_size);
}
LoD top_lod;
top_lod.push_back(top_offset);
top->set_lod(top_lod);
std::vector<int64_t> top_dims_vec{top_size};
top_dims_vec.push_back(1);
top->Resize(top_dims_vec);
auto* top_data = top->mutable_data<T>();
const auto* w_data = w->data<T>();
const auto* col_data = col->data<T>();
auto blas = lite::x86::math::GetBlas<lite::TargetType::kX86, T>(context);
for (int b = 0; b < batch; ++b) {
int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
if (top_im_size == 0) {
continue;
}
blas.GEMM(false,
false,
output_channel,
top_im_size,
input_channel * kernel_h * kernel_w,
1.0,
w_data,
input_channel * kernel_h * kernel_w,
col_data + col_offset[b],
top_im_size,
0.0,
top_data + top_offset[b],
top_im_size);
}
}
virtual ~VarConv2DCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/var_conv_2d_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
static void im2col_ref(const lite::Tensor& input,
const lite::Tensor* in_row,
const lite::Tensor* in_col,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int input_channel,
lite::Tensor* col) {
int batch = input.lod()[0].size() - 1;
const auto& bottom_offset = input.lod()[0];
// 2-D lod info.
const auto& offset_x = in_col->lod()[0];
const auto& offset_y = in_row->lod()[0];
// top offset is the whole size of each data sample
std::vector<uint64_t> top_offset;
int top_size = 0;
top_offset.push_back(top_size);
for (int b = 0; b < batch; ++b) {
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
int top_im_x = 0;
if (width == 0) {
top_im_x = 0;
} else {
top_im_x = (width - 1) / stride_w + 1;
}
int top_im_y = 0;
if (height == 0) {
top_im_y = 0;
} else {
top_im_y = (height - 1) / stride_h + 1;
}
int top_x = top_im_x * top_im_y;
int top_y = input_channel * kernel_h * kernel_w;
top_size += top_y * top_x;
top_offset.push_back(top_size);
}
LoD col_lod;
col_lod.push_back(top_offset);
col->set_lod(col_lod);
std::vector<int64_t> col_dims_vec{top_size};
col_dims_vec.push_back(1);
col->Resize(col_dims_vec);
auto* top_data = col->mutable_data<float>();
const auto* bottom_data = input.data<float>();
int kernel_win_size = kernel_h * kernel_w;
int half_kernel_h = kernel_h / 2;
int half_kernel_w = kernel_w / 2;
for (int b = 0; b < batch; ++b) {
int t_offset = top_offset[b];
int b_offset = bottom_offset[b];
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
if (width == 0 || height == 0) {
continue;
}
int top_im_x = (width - 1) / stride_w + 1;
int top_im_y = (height - 1) / stride_h + 1;
int top_x = top_im_y * top_im_x;
for (int z = 0; z < input_channel; ++z) {
int row_offset = kernel_win_size * z;
int im_offset = z * width * height;
for (int y = 0; y < height; y += stride_h) {
for (int x = 0; x < width; x += stride_w) {
int col_offset = x / stride_w + y / stride_h * top_im_x;
for (int ky = 0; ky < kernel_h; ++ky) {
for (int kx = 0; kx < kernel_w; ++kx) {
int im_y = y + ky - half_kernel_h;
int im_x = x + kx - half_kernel_w;
if (im_x >= 0 && im_x < width && im_y >= 0 && im_y < height) {
top_data[t_offset + (row_offset + ky * kernel_w + kx) * top_x +
col_offset] =
bottom_data[b_offset + im_offset + im_y * width + im_x];
} else {
top_data[t_offset + (row_offset + ky * kernel_w + kx) * top_x +
col_offset] = 0;
}
}
}
}
}
}
}
}
static void var_conv_2d_ref(const lite::Tensor* bottom,
const lite::Tensor* w,
const lite::Tensor* in_row,
const lite::Tensor* in_col,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int input_channel,
const int output_channel,
lite::Tensor* top,
lite::Tensor* col) {
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<X86Context>();
im2col_ref(*bottom,
in_row,
in_col,
kernel_h,
kernel_w,
stride_h,
stride_w,
input_channel,
col);
int batch = bottom->lod()[0].size() - 1;
const auto& col_offset = col->lod()[0];
const auto& offset_x = in_col->lod()[0];
const auto& offset_y = in_row->lod()[0];
std::vector<size_t> top_offset;
int top_size = 0;
top_offset.push_back(top_size);
for (int b = 0; b < batch; ++b) {
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
int top_im_x = 0;
if (width == 0) {
top_im_x = 0;
} else {
top_im_x = (width - 1) / stride_w + 1;
}
int top_im_y = 0;
if (height == 0) {
top_im_y = 0;
} else {
top_im_y = (height - 1) / stride_h + 1;
}
int top_im_size = top_im_y * top_im_x;
top_size += output_channel * top_im_size;
top_offset.push_back(top_size);
}
LoD top_lod;
top_lod.push_back(top_offset);
top->set_lod(top_lod);
std::vector<int64_t> top_dims_vec{top_size};
top_dims_vec.push_back(1);
top->Resize(top_dims_vec);
auto* top_data = top->mutable_data<float>();
const auto* w_data = w->data<float>();
const auto* col_data = col->data<float>();
auto blas = lite::x86::math::GetBlas<lite::TargetType::kX86, float>(context);
for (int b = 0; b < batch; ++b) {
int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
if (top_im_size == 0) {
continue;
}
blas.GEMM(false,
false,
output_channel,
top_im_size,
input_channel * kernel_h * kernel_w,
1.0,
w_data,
input_channel * kernel_h * kernel_w,
col_data + col_offset[b],
top_im_size,
0.0,
top_data + top_offset[b],
top_im_size);
}
}
TEST(var_conv_2d_x86, retrive_op) {
auto var_conv_2d =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"var_conv_2d");
ASSERT_FALSE(var_conv_2d.empty());
ASSERT_TRUE(var_conv_2d.front());
}
TEST(var_conv_2d_x86, init) {
VarConv2DCompute<float> var_conv_2d;
ASSERT_EQ(var_conv_2d.precision(), PRECISION(kFloat));
ASSERT_EQ(var_conv_2d.target(), TARGET(kX86));
}
TEST(var_conv_2d_x86, run_test) {
VarConv2DCompute<float> var_conv_2d;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
operators::VarConv2DParam param;
lite::Tensor X, W, ROW, COLUMN;
lite::Tensor Out, Col;
int kernel_h, kernel_w;
int stride_h, stride_w;
int input_channel, output_channel;
output_channel = 5;
input_channel = 5;
kernel_h = 5;
kernel_w = 5;
stride_h = 1;
stride_w = 1;
std::vector<int64_t> w_dims_vec;
w_dims_vec.push_back(output_channel);
w_dims_vec.push_back(input_channel * kernel_h * kernel_w);
W.Resize(w_dims_vec);
auto* w_data = W.mutable_data<float>();
for (int i = 0; i < W.numel(); ++i) {
w_data[i] = i - 1.f;
}
std::vector<uint64_t> row_lod_vec{0, 10, 20};
LoD row_lod;
row_lod.push_back(row_lod_vec);
ROW.set_lod(row_lod);
std::vector<uint64_t> column_lod_vec{0, 10, 20};
LoD column_lod;
column_lod.push_back(column_lod_vec);
COLUMN.set_lod(column_lod);
int x_size = 0;
std::vector<uint64_t> x_lod_vec;
for (size_t i = 0; i < row_lod_vec.size() - 1; ++i) {
int height = row_lod_vec[i + 1] - row_lod_vec[i];
int width = column_lod_vec[i + 1] - column_lod_vec[i];
x_lod_vec.push_back(height * width);
x_size += height * width;
}
x_size *= input_channel;
std::vector<int64_t> x_dims_vec{x_size, 1};
LoD x_lod;
x_lod.push_back(x_lod_vec);
X.Resize(x_dims_vec);
X.set_lod(x_lod);
auto* x_data = X.mutable_data<float>();
for (int i = 0; i < X.numel(); ++i) {
x_data[i] = i % 20 * 1.f;
}
param.X = &X;
param.W = &W;
param.ROW = &ROW;
param.COLUMN = &COLUMN;
param.Out = &Out;
param.Col = &Col;
param.stride_h = stride_h;
param.stride_w = stride_w;
param.kernel_h = kernel_h;
param.kernel_w = kernel_w;
param.input_channel = input_channel;
param.output_channel = output_channel;
var_conv_2d.SetParam(param);
var_conv_2d.SetContext(std::move(ctx));
var_conv_2d.Run();
lite::Tensor top_ref, col_ref;
var_conv_2d_ref(&X,
&W,
&ROW,
&COLUMN,
kernel_h,
kernel_w,
stride_h,
stride_w,
input_channel,
output_channel,
&top_ref,
&col_ref);
for (int i = 0; i < Out.numel(); ++i) {
EXPECT_NEAR(Out.data<float>()[i], top_ref.data<float>()[i], 1e-5);
}
for (int i = 0; i < Col.numel(); ++i) {
EXPECT_NEAR(Col.data<float>()[i], col_ref.data<float>()[i], 1e-5);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(var_conv_2d, kX86, kFloat, kNCHW, def);
...@@ -81,6 +81,7 @@ add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${o ...@@ -81,6 +81,7 @@ add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${o
add_operator(sequence_reverse_op_lite extra SRCS sequence_reverse_op.cc DEPS ${op_DEPS}) add_operator(sequence_reverse_op_lite extra SRCS sequence_reverse_op.cc DEPS ${op_DEPS})
add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS}) add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS})
add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_DEPS}) add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_DEPS})
add_operator(var_conv_2d_op_lite extra SRCS var_conv_2d_op.cc DEPS ${op_DEPS})
# for OCR specific # for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
...@@ -789,6 +789,22 @@ struct ReduceParam { ...@@ -789,6 +789,22 @@ struct ReduceParam {
bool reduce_all{false}; bool reduce_all{false};
}; };
struct VarConv2DParam {
const lite::Tensor* X{};
const lite::Tensor* ROW{};
const lite::Tensor* COLUMN{};
const lite::Tensor* W{};
lite::Tensor* Out{};
lite::Tensor* Col{};
int input_channel;
int output_channel;
int stride_h;
int stride_w;
int kernel_h;
int kernel_w;
};
/// ----------------------- shape operators ---------------------- /// ----------------------- shape operators ----------------------
struct ShapeParam { struct ShapeParam {
const lite::Tensor* X{}; const lite::Tensor* X{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/var_conv_2d_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool VarConv2dOp::CheckShape() const {
auto x_dims = param_.X->dims();
CHECK_EQ(x_dims.size(), 2) << "The rank of X(Input) can't be less than 2.";
auto w_dims = param_.W->dims();
CHECK_EQ(w_dims.size(), 2) << "W should be 2-D tensor";
CHECK_EQ(w_dims[0], param_.output_channel)
<< "W dim[0] should be equal to OutputChannel";
CHECK_EQ(w_dims[1], param_.input_channel * param_.kernel_h * param_.kernel_w)
<< "W dim[1] should be equal to InputChannel * KernelH * KernelW";
LoD x_lod = param_.X->lod();
CHECK_EQ(x_lod.empty(), false) << "The Input(X) must hold lod info.";
CHECK_GE(x_lod.size(), 1) << "The Input(X)'s lod info is corrupted.";
CHECK_EQ(x_dims[0], static_cast<int64_t>(x_lod[0].back()))
<< "The Input(X)'s lod info mismatches the actual tensor shape.";
LoD row_lod = param_.ROW->lod();
CHECK_EQ(row_lod.empty(), false) << "The Input(ROW) must hold lod info.";
LoD col_lod = param_.COLUMN->lod();
CHECK_EQ(col_lod.empty(), false) << "The Input(COLUMN) must hold lod info.";
return true;
}
bool VarConv2dOp::InferShape() const { return true; }
bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.ROW = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("ROW").front())->Get<lite::Tensor>());
param_.COLUMN = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("COLUMN").front())->Get<lite::Tensor>());
param_.W = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("W").front())->Get<lite::Tensor>());
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.Col =
scope->FindVar(opdesc.Output("Col").front())->GetMutable<lite::Tensor>();
CHECK(param_.X) << "X(Input) of VarConv2dOP should not be null.";
CHECK(param_.ROW) << "Input(ROW) of VarConv2dOP should not be null.";
CHECK(param_.COLUMN) << "Input(COLUMN) of VarConv2dOP should not be null.";
CHECK(param_.W) << "W(Input) of VarConv2dOP should not be null.";
CHECK(param_.Out) << "Out(Output) of VarConv2dOP should not be null.";
CHECK(param_.Col) << "Col(Output) of VarConv2dOP should not be null.";
param_.output_channel = opdesc.GetAttr<int>("OutputChannel");
param_.input_channel = opdesc.GetAttr<int>("InputChannel");
param_.kernel_h = opdesc.GetAttr<int>("KernelH");
param_.kernel_w = opdesc.GetAttr<int>("KernelW");
param_.stride_h = opdesc.GetAttr<int>("StrideH");
param_.stride_w = opdesc.GetAttr<int>("StrideW");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(var_conv_2d, paddle::lite::operators::VarConv2dOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
namespace paddle {
namespace lite {
namespace operators {
class VarConv2dOp : public OpLite {
public:
VarConv2dOp() {}
explicit VarConv2dOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "var_conv_2d"; }
private:
mutable VarConv2DParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册