Commit b152dbb4 authored by Superjomn

add GoogleNet related kernels unit tests

Parent ce6c24e6
@@ -18,6 +18,18 @@ cc_library(concat_compute_x86 SRCS concat_compute.cc DEPS ${lite_kernel_deps} )
cc_library(conv_compute_x86 SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col)
cc_library(pool_compute_x86 SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling)
lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86)
lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86)
lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86)
lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86)
lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86)
lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS relu_compute_x86)
lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86 operator)
lite_cc_test(test_scale_compute_x86 SRCS scale_compute_test.cc DEPS scale_compute_x86)
lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
set(x86_kernels
activation_compute_x86
elementwise_compute_x86
......
@@ -12,88 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <Eigen/Core>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/operators/strided_memcpy.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class ConcatCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ConcatParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
int64_t axis = static_cast<int64_t>(param.axis);
auto out = param.output;
if (axis == 0 && param.x.size() < 10) {
size_t output_offset = 0;
for (auto* in : param.x) {
if (!in || in->dims().production() == 0UL) {
continue;
}
auto in_stride = framework::stride_numel(in->dims().data());
auto out_stride = framework::stride_numel(out->dims().data());
paddle::operators::StridedNumelCopyWithAxis<T>(
platform::CPUDeviceContext(), axis,
out->mutable_data<T>() + output_offset, out_stride, in->data<T>(),
in_stride, in_stride[axis]);
output_offset += in_stride[axis];
}
} else {
std::vector<lite::Tensor> inputs;
for (size_t j = 0; j < param.x.size(); ++j) {
if (param.x[j] && param.x[j]->dims().production() > 0) {
inputs.push_back(*param.x[j]);
} else {
continue;
}
}
int num = inputs.size();
int rows = 1;
auto dim_0 = inputs[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int out_rows = rows, out_cols = 0;
std::vector<int64_t> input_cols(inputs.size());
for (int i = 0; i < num; ++i) {
int t_cols = inputs[i].dims().production() / rows;
out_cols += t_cols;
input_cols[i] = t_cols;
}
// computation
auto output_data = param.output->template mutable_data<T>();
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
auto input_data = inputs[j].data<float>();
for (int k = 0; k < out_rows; ++k) {
std::memcpy(output_data + k * out_cols + col_idx,
input_data + k * col_len, sizeof(T) * col_len);
}
col_idx += col_len;
}
}
}
virtual ~ConcatCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
#include "paddle/fluid/lite/kernels/x86/concat_compute.h"
REGISTER_LITE_KERNEL(concat, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::ConcatCompute<float>, def)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <cstring>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/operators/strided_memcpy.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class ConcatCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ConcatParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
int64_t axis = static_cast<int64_t>(param.axis);
auto out = param.output;
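// Fast path: for an axis-0 concat over fewer than ten inputs, copy each
// input into the output at a running offset via a strided memcpy. All other
// cases flatten every input to a (rows x cols) matrix and splice the column
// blocks together row by row in the else branch below.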
if (axis == 0 && param.x.size() < 10) {
size_t output_offset = 0;
for (auto* in : param.x) {
if (!in || in->dims().production() == 0UL) {
continue;
}
auto in_stride = framework::stride_numel(in->dims().data());
auto out_stride = framework::stride_numel(out->dims().data());
paddle::operators::StridedNumelCopyWithAxis<T>(
platform::CPUDeviceContext(), axis,
out->mutable_data<T>() + output_offset, out_stride, in->data<T>(),
in_stride, in_stride[axis]);
output_offset += in_stride[axis];
}
} else {
std::vector<lite::Tensor> inputs;
for (size_t j = 0; j < param.x.size(); ++j) {
if (param.x[j] && param.x[j]->dims().production() > 0) {
inputs.push_back(*param.x[j]);
} else {
continue;
}
}
int num = inputs.size();
int rows = 1;
auto dim_0 = inputs[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int out_rows = rows, out_cols = 0;
std::vector<int64_t> input_cols(inputs.size());
for (int i = 0; i < num; ++i) {
int t_cols = inputs[i].dims().production() / rows;
out_cols += t_cols;
input_cols[i] = t_cols;
}
// computation
auto output_data = param.output->template mutable_data<T>();
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
auto input_data = inputs[j].data<float>();
for (int k = 0; k < out_rows; ++k) {
std::memcpy(output_data + k * out_cols + col_idx,
input_data + k * col_len, sizeof(T) * col_len);
}
col_idx += col_len;
}
}
}
virtual ~ConcatCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/x86/concat_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(concat_x86, retrive_op) {
auto concat =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"concat");
ASSERT_FALSE(concat.empty());
ASSERT_TRUE(concat.front());
}
TEST(concat_x86, init) {
ConcatCompute<float> concat;
ASSERT_EQ(concat.precision(), PRECISION(kFloat));
ASSERT_EQ(concat.target(), TARGET(kX86));
}
TEST(concat_x86, run_test) {
lite::Tensor x1, x2, out;
constexpr int batch_size = 1;
std::vector<int64_t> x1_shape{batch_size, 1, 3, 3};
x1.Resize(lite::DDim(x1_shape));
std::vector<int64_t> x2_shape{batch_size, 1, 3, 3};
x2.Resize(lite::DDim(x2_shape));
std::vector<lite::Tensor*> x = {&x1, &x2};
std::vector<int64_t> out_shape{batch_size, 2, 3, 3};
out.Resize(lite::DDim(out_shape));
auto x1_data = x1.mutable_data<float>();
auto x2_data = x2.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x1.dims().production(); i++) {
x1_data[i] = 1;
x2_data[i] = 2;
}
ConcatCompute<float> concat;
operators::ConcatParam param;
param.x = x;
param.output = &out;
param.axis = 1;
concat.SetParam(param);
concat.Run();
std::cout << "output: ";
for (int i = 0; i < out.dims().production(); i++) {
std::cout << out_data[i] << " ";
}
std::cout << std::endl;
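// Minimal correctness check, a hedged expectation derived from the fills
// above: concatenating along axis 1 puts x1's channel first, then x2's.
for (int i = 0; i < 9; i++) {
  EXPECT_FLOAT_EQ(out_data[i], 1.f);
  EXPECT_FLOAT_EQ(out_data[i + 9], 2.f);
}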
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(concat, kX86, kFloat, kNCHW, def);
@@ -12,144 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <Eigen/Core>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/operators/conv_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
inline bool IsExpand(const std::vector<int64_t>& filter_dim,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations) {
bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
for (size_t j = 0; j < strides.size(); ++j) {
filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
strides_1 = strides_1 && (strides[j] == 1);
padding_0 = padding_0 && (paddings[j] == 0);
dilation_1 = dilation_1 && (dilations[j] == 1);
}
return !(filter_1 && strides_1 && padding_0 && dilation_1);
}
template <typename T>
class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ConvParam;
void Run() override {
auto& param = *param_.get_mutable<operators::ConvParam>();
lite::Tensor filter = *param.filter;
param.output->template mutable_data<T>();
const int batch_size = static_cast<int>(param.x->dims()[0]);
std::vector<int64_t> filter_shape_vec(filter.dims().Vectorize());
std::vector<int64_t> output_shape_vec(param.output->dims().Vectorize());
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = param.x->dims()[1] / param.groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
lite::DDim col_shape(col_shape_vec);
lite::DDim col_matrix_shape = col_shape.Flattern2D(data_dim + 1);
bool is_expand = IsExpand(filter_shape_vec, param.strides, param.paddings,
param.dilations);
lite::Tensor col;
lite::Tensor col_matrix;
if (is_expand) {
col.Resize(col_shape);
col.mutable_data<T>();
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
lite::DDim input_shape = param.x->dims().Slice(1, param.x->dims().size());
lite::DDim filter_matrix_shape(std::vector<int64_t>{
filter.dims()[0], filter.dims().production() / filter.dims()[0]});
filter.Resize(filter_matrix_shape);
lite::DDim output_matrix_shape(std::vector<int64_t>{
param.output->dims()[1],
param.output->dims().production() /
(param.output->dims()[0] * param.output->dims()[1])});
int in_step = static_cast<int>(param.x->dims()[1]) / param.groups;
int out_step = static_cast<int>(param.output->dims()[1]) / param.groups;
paddle::operators::math::Vol2ColFunctor<platform::CPUDeviceContext, T>
vol2col;
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO, platform::CPUDeviceContext, T>
im2col;
auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
platform::CPUDeviceContext());
for (int i = 0; i < batch_size; i++) {
lite::Tensor in_batch;
in_batch.ShareDataWith(
param.x->raw_tensor().Slice(i, i + 1).Resize(input_shape.data()));
lite::Tensor out_batch;
out_batch.ShareDataWith(param.output->raw_tensor().Slice(i, i + 1).Resize(
output_matrix_shape.data()));
for (int g = 0; g < param.groups; g++) {
lite::Tensor in_slice;
in_slice.ShareDataWith(
in_batch.raw_tensor().Slice(g * in_step, (g + 1) * in_step));
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(platform::CPUDeviceContext(), in_slice.raw_tensor(),
param.dilations, param.strides,
std::vector<int>{param.paddings[0], param.paddings[1],
param.paddings[0], param.paddings[1]},
&(col.raw_tensor()));
} else if (data_dim == 3U) {
// vol2col
vol2col(platform::CPUDeviceContext(), in_slice.raw_tensor(),
param.dilations, param.strides, param.paddings,
&(col.raw_tensor()));
}
// gemm
lite::Tensor out_slice;
out_slice.ShareDataWith(
out_batch.raw_tensor().Slice(g * out_step, (g + 1) * out_step));
lite::Tensor filter_slice;
filter_slice.ShareDataWith(
filter.raw_tensor().Slice(g * out_step, (g + 1) * out_step));
blas.MatMul(filter_slice.raw_tensor(), false, col_matrix.raw_tensor(),
false, T(1.0), &(out_slice.raw_tensor()), T(0.0));
}
}
}
virtual ~Conv2dCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
#include "paddle/fluid/lite/kernels/x86/conv_compute.h"
REGISTER_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::Conv2dCompute<float>, def)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/operators/conv_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
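// IsExpand reports whether the input must be expanded through im2col/vol2col.
// Expansion is skipped only for the pure 1x1-filter, stride-1, zero-padding,
// dilation-1 case, where each input slice can feed the GEMM directly.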
inline bool IsExpand(const std::vector<int64_t>& filter_dim,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations) {
bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
for (size_t j = 0; j < strides.size(); ++j) {
filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
strides_1 = strides_1 && (strides[j] == 1);
padding_0 = padding_0 && (paddings[j] == 0);
dilation_1 = dilation_1 && (dilations[j] == 1);
}
return !(filter_1 && strides_1 && padding_0 && dilation_1);
}
template <typename T>
class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ConvParam;
void Run() override {
auto& param = *param_.get_mutable<operators::ConvParam>();
lite::Tensor filter = *param.filter;
param.output->template mutable_data<T>();
const int batch_size = static_cast<int>(param.x->dims()[0]);
std::vector<int64_t> filter_shape_vec(filter.dims().Vectorize());
std::vector<int64_t> output_shape_vec(param.output->dims().Vectorize());
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = param.x->dims()[1] / param.groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
lite::DDim col_shape(col_shape_vec);
lite::DDim col_matrix_shape = col_shape.Flattern2D(data_dim + 1);
bool is_expand = IsExpand(filter_shape_vec, param.strides, param.paddings,
param.dilations);
lite::Tensor col;
lite::Tensor col_matrix;
if (is_expand) {
col.Resize(col_shape);
col.mutable_data<T>();
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
lite::DDim input_shape = param.x->dims().Slice(1, param.x->dims().size());
lite::DDim filter_matrix_shape(std::vector<int64_t>{
filter.dims()[0], filter.dims().production() / filter.dims()[0]});
filter.Resize(filter_matrix_shape);
lite::DDim output_matrix_shape(std::vector<int64_t>{
param.output->dims()[1],
param.output->dims().production() /
(param.output->dims()[0] * param.output->dims()[1])});
int in_step = static_cast<int>(param.x->dims()[1]) / param.groups;
int out_step = static_cast<int>(param.output->dims()[1]) / param.groups;
paddle::operators::math::Vol2ColFunctor<platform::CPUDeviceContext, T>
vol2col;
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO, platform::CPUDeviceContext, T>
im2col;
auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
platform::CPUDeviceContext());
for (int i = 0; i < batch_size; i++) {
lite::Tensor in_batch;
in_batch.ShareDataWith(
param.x->raw_tensor().Slice(i, i + 1).Resize(input_shape.data()));
lite::Tensor out_batch;
out_batch.ShareDataWith(param.output->raw_tensor().Slice(i, i + 1).Resize(
output_matrix_shape.data()));
for (int g = 0; g < param.groups; g++) {
lite::Tensor in_slice;
in_slice.ShareDataWith(
in_batch.raw_tensor().Slice(g * in_step, (g + 1) * in_step));
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(platform::CPUDeviceContext(), in_slice.raw_tensor(),
param.dilations, param.strides,
std::vector<int>{param.paddings[0], param.paddings[1],
param.paddings[0], param.paddings[1]},
&(col.raw_tensor()));
} else if (data_dim == 3U) {
// vol2col
vol2col(platform::CPUDeviceContext(), in_slice.raw_tensor(),
param.dilations, param.strides, param.paddings,
&(col.raw_tensor()));
}
// gemm
lite::Tensor out_slice;
out_slice.ShareDataWith(
out_batch.raw_tensor().Slice(g * out_step, (g + 1) * out_step));
lite::Tensor filter_slice;
filter_slice.ShareDataWith(
filter.raw_tensor().Slice(g * out_step, (g + 1) * out_step));
blas.MatMul(filter_slice.raw_tensor(), false, col_matrix.raw_tensor(),
false, T(1.0), &(out_slice.raw_tensor()), T(0.0));
}
}
}
virtual ~Conv2dCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/x86/conv_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(conv_x86, retrive_op) {
auto conv2d =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"conv2d");
ASSERT_FALSE(conv2d.empty());
ASSERT_TRUE(conv2d.front());
}
TEST(conv2d_x86, init) {
Conv2dCompute<float> conv2d;
ASSERT_EQ(conv2d.precision(), PRECISION(kFloat));
ASSERT_EQ(conv2d.target(), TARGET(kX86));
}
TEST(conv2d_x86, run_test) {
lite::Tensor x, filter, b, out;
constexpr int batch_size = 1;
std::vector<int64_t> x_shape{batch_size, 3, 3, 3};
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> filter_shape{1, 3, 3, 3};
filter.Resize(lite::DDim(filter_shape));
std::vector<int64_t> b_shape{1, 3, 1, 1};
b.Resize(lite::DDim(b_shape));
std::vector<int64_t> out_shape{batch_size, 1, 1, 1};
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto filter_data = filter.mutable_data<float>();
auto b_data = b.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); i++) {
x_data[i] = 1;
}
for (int64_t i = 0; i < filter.dims().production(); i++) {
filter_data[i] = 1;
}
for (int64_t i = 0; i < b.dims().production(); i++) {
b_data[i] = 0;
}
Conv2dCompute<float> conv2d;
operators::ConvParam param;
param.x = &x;
param.filter = &filter;
param.bias = &b;
param.output = &out;
param.strides = {1, 1};
param.paddings = {0, 0};
param.groups = 1;
param.dilations = {1, 1};
conv2d.SetParam(param);
conv2d.Run();
LOG(INFO) << "output: ";
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i] << " ";
}
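// Hedged sanity check: an all-ones (1,3,3,3) input convolved with an
// all-ones (1,3,3,3) filter at stride 1 with no padding yields a single
// output value of 3 * 3 * 3 = 27 (the bias is zero here anyway).
EXPECT_NEAR(out_data[0], 27.f, 1e-5);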
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW, def);
@@ -12,72 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <random>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
class DropoutCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::DropoutParam;
void Run() override {
auto& param = *param_.get_mutable<operators::DropoutParam>();
const auto* x_data = param.x->data<T>();
auto* out_data = param.output->template mutable_data<T>();
if (!param.is_test) {
auto* mask_data = param.mask->template mutable_data<T>();
std::random_device rnd;
std::minstd_rand engine;
int seed = param.fix_seed ? param.seed : rnd();
engine.seed(seed);
std::uniform_real_distribution<float> dist(0, 1);
size_t size = framework::product(param.mask->dims().data());
for (size_t i = 0; i < size; ++i) {
if (dist(engine) < param.dropout_prob) {
mask_data[i] = 0;
out_data[i] = 0;
} else {
if (param.dropout_implementation == "upscale_in_train") {
mask_data[i] = 1.0f / static_cast<T>(1.0f - param.dropout_prob);
out_data[i] = x_data[i] / static_cast<T>(1.0f - param.dropout_prob);
} else {
mask_data[i] = 1;
out_data[i] = x_data[i];
}
}
}
} else {
auto X = EigenMatrix<T>::Reshape(param.x->raw_tensor(), 1);
auto Y = EigenMatrix<T>::Reshape(param.output->raw_tensor(), 1);
auto& place = *platform::CPUDeviceContext().eigen_device();
if (param.dropout_implementation == "upscale_in_train") {
Y.device(place) = X;
} else {
Y.device(place) = X * static_cast<T>(1.0f - param.dropout_prob);
}
}
}
virtual ~DropoutCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
#include "paddle/fluid/lite/kernels/x86/dropout_compute.h"
REGISTER_LITE_KERNEL(dropout, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::DropoutCompute<float>, def)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <random>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
class DropoutCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::DropoutParam;
void Run() override {
auto& param = *param_.get_mutable<operators::DropoutParam>();
const auto* x_data = param.x->data<T>();
auto* out_data = param.output->template mutable_data<T>();
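// Training path: draw a uniform random mask per element. With
// "upscale_in_train" the kept activations are rescaled by 1/(1 - p) so that
// inference can pass the input through unchanged; otherwise inference
// scales by (1 - p) instead (see the else branch below).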
if (!param.is_test) {
auto* mask_data = param.mask->template mutable_data<T>();
std::random_device rnd;
std::minstd_rand engine;
int seed = param.fix_seed ? param.seed : rnd();
engine.seed(seed);
std::uniform_real_distribution<float> dist(0, 1);
size_t size = framework::product(param.mask->dims().data());
for (size_t i = 0; i < size; ++i) {
if (dist(engine) < param.dropout_prob) {
mask_data[i] = 0;
out_data[i] = 0;
} else {
if (param.dropout_implementation == "upscale_in_train") {
mask_data[i] = 1.0f / static_cast<T>(1.0f - param.dropout_prob);
out_data[i] = x_data[i] / static_cast<T>(1.0f - param.dropout_prob);
} else {
mask_data[i] = 1;
out_data[i] = x_data[i];
}
}
}
} else {
auto X = EigenMatrix<T>::Reshape(param.x->raw_tensor(), 1);
auto Y = EigenMatrix<T>::Reshape(param.output->raw_tensor(), 1);
auto& place = *platform::CPUDeviceContext().eigen_device();
if (param.dropout_implementation == "upscale_in_train") {
Y.device(place) = X;
} else {
Y.device(place) = X * static_cast<T>(1.0f - param.dropout_prob);
}
}
}
virtual ~DropoutCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/x86/dropout_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(dropout_x86, retrive_op) {
auto dropout =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"dropout");
ASSERT_FALSE(dropout.empty());
ASSERT_TRUE(dropout.front());
}
TEST(dropout_x86, init) {
DropoutCompute<float> dropout;
ASSERT_EQ(dropout.precision(), PRECISION(kFloat));
ASSERT_EQ(dropout.target(), TARGET(kX86));
}
TEST(dropout_x86, run_test) {
lite::Tensor x, y, out;
constexpr int batch_size = 1;
std::vector<int64_t> x_shape{batch_size, 3, 2, 2};
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape{batch_size, 3, 2, 2};
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); i++) {
x_data[i] = static_cast<float>(i);
}
// DropoutCompute dropout;
DropoutCompute<float> dropout;
operators::DropoutParam param;
param.x = &x;
param.dropout_prob = 0.25;
param.is_test = true;
param.fix_seed = true;
param.output = &out;
dropout.SetParam(param);
dropout.Run();
LOG(INFO) << "output: ";
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
}
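// Hedged expectation: is_test is true and dropout_implementation keeps its
// default, so this assumes the inference path scales by (1 - dropout_prob),
// i.e. out[i] == 0.75 * x[i].
for (int64_t i = 0; i < out.dims().production(); i++) {
  EXPECT_NEAR(out_data[i], 0.75f * x_data[i], 1e-5);
}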
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, def);
@@ -12,113 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/lite/kernels/x86/elementwise_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
struct SubFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
};
template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};
template <typename T>
class ElementwiseSubCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context());
param.Out->template mutable_data<T>();
paddle::operators::ElementwiseComputeEx<SubFunctor<T>,
platform::CPUDeviceContext, T>(
*context.x86_execution_context(), &param.X->raw_tensor(),
&param.Y->raw_tensor(), param.axis, SubFunctor<T>(),
&param.Out->raw_tensor());
}
virtual ~ElementwiseSubCompute() = default;
};
template <typename T>
struct SubGradDX {
T operator()(T x, T y, T out, T dout) const { return dout; }
};
template <typename T>
struct SubGradDY {
T operator()(T x, T y, T out, T dout) const { return -dout; }
};
template <typename T>
class ElementwiseSubGradCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ElementwiseGradParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context());
param.X_grad->template mutable_data<T>();
param.Y_grad->template mutable_data<T>();
// skip out, x, y
auto dout = param.Out_grad->raw_tensor();
auto dx = param.X_grad->raw_tensor();
auto dy = param.Y_grad->raw_tensor();
auto& skip = dout;
paddle::operators::ElemwiseExplicitGradCompute<
platform::CPUDeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
*context.x86_execution_context(), skip, skip, skip, dout, param.axis,
&dx, &dy, SubGradDX<T>(), SubGradDY<T>());
}
virtual ~ElementwiseSubGradCompute() = default;
};
template <typename T>
class ElementwiseAddCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context());
param.Out->template mutable_data<T>();
paddle::operators::ElementwiseComputeEx<AddFunctor<T>,
platform::CPUDeviceContext, T>(
*context.x86_execution_context(), &param.X->raw_tensor(),
&param.Y->raw_tensor(), param.axis, AddFunctor<T>(),
&param.Out->raw_tensor());
}
virtual ~ElementwiseAddCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// float
REGISTER_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::ElementwiseSubCompute<float>,
def)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
struct SubFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
};
template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};
template <typename T>
class ElementwiseSubCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context());
param.Out->template mutable_data<T>();
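// param.axis gives the offset at which Y's dimensions align with X's;
// ElementwiseComputeEx broadcasts Y over the remaining dimensions.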
paddle::operators::ElementwiseComputeEx<SubFunctor<T>,
platform::CPUDeviceContext, T>(
*context.x86_execution_context(), &param.X->raw_tensor(),
&param.Y->raw_tensor(), param.axis, SubFunctor<T>(),
&param.Out->raw_tensor());
}
virtual ~ElementwiseSubCompute() = default;
};
template <typename T>
struct SubGradDX {
T operator()(T x, T y, T out, T dout) const { return dout; }
};
template <typename T>
struct SubGradDY {
T operator()(T x, T y, T out, T dout) const { return -dout; }
};
template <typename T>
class ElementwiseSubGradCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ElementwiseGradParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context());
param.X_grad->template mutable_data<T>();
param.Y_grad->template mutable_data<T>();
// skip out, x, y
auto dout = param.Out_grad->raw_tensor();
auto dx = param.X_grad->raw_tensor();
auto dy = param.Y_grad->raw_tensor();
auto& skip = dout;
paddle::operators::ElemwiseExplicitGradCompute<
platform::CPUDeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
*context.x86_execution_context(), skip, skip, skip, dout, param.axis,
&dx, &dy, SubGradDX<T>(), SubGradDY<T>());
}
virtual ~ElementwiseSubGradCompute() = default;
};
template <typename T>
class ElementwiseAddCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context());
param.Out->template mutable_data<T>();
paddle::operators::ElementwiseComputeEx<AddFunctor<T>,
platform::CPUDeviceContext, T>(
*context.x86_execution_context(), &param.X->raw_tensor(),
&param.Y->raw_tensor(), param.axis, AddFunctor<T>(),
&param.Out->raw_tensor());
}
virtual ~ElementwiseAddCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/x86/elementwise_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(elementwise_add_x86, retrive_op) {
auto elementwise_add =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"elementwise_add");
ASSERT_FALSE(elementwise_add.empty());
ASSERT_TRUE(elementwise_add.front());
}
TEST(elementwise_add_x86, init) {
ElementwiseAddCompute<float> elementwise_add;
ASSERT_EQ(elementwise_add.precision(), PRECISION(kFloat));
ASSERT_EQ(elementwise_add.target(), TARGET(kX86));
}
TEST(elementwise_add_x86, run_test) {
lite::Tensor x, y, out;
constexpr int batch_size = 1;
std::vector<int64_t> x_shape{batch_size, 3, 2, 2};
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> y_shape{batch_size, 3, 2, 2};
y.Resize(lite::DDim(y_shape));
std::vector<int64_t> out_shape{batch_size, 3, 2, 2};
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto y_data = y.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); i++) {
x_data[i] = 1;
}
for (int64_t i = 0; i < y.dims().production(); i++) {
y_data[i] = 2;
}
// ElementwiseAddCompute elementwise_add;
ElementwiseAddCompute<float> elementwise_add;
operators::ElementwiseParam param;
param.X = &x;
param.Y = &y;
param.Out = &out;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
elementwise_add.SetParam(param);
elementwise_add.SetContext(std::move(ctx));
elementwise_add.Run();
LOG(INFO) << "output: ";
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
}
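// Expected result derived from the fills above: 1 + 2 = 3 for every element.
for (int i = 0; i < out.dims().production(); i++) {
  EXPECT_NEAR(out_data[i], 3.f, 1e-5);
}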
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def);
@@ -12,89 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <Eigen/Core>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
#include "paddle/fluid/lite/operators/fc_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
void fc_compute_eigen(const T* x, int x_h, int x_w, //
const T* w, int w_h, int w_w, //
const T* b, //
T* out) {
using matrix_t =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
Eigen::Map<const matrix_t> X(x, x_h, x_w);
Eigen::Map<const matrix_t> W(w, w_h, w_w);
Eigen::Map<matrix_t> Out(out, x_h, w_w);
Out = X * W;
if (b) {
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> B(b, w_w);
Out = Out.array().rowwise() + B.transpose().array();
}
}
template <typename T>
void fc_compute_naive(const T* x, int x_h, int x_w, //
const T* w, int w_h, int w_w, //
const T* b, //
T* out) {
CHECK_EQ(x_w, w_h);
// out shape: (x_h, w_w)
memset(out, 0, x_h * w_w * sizeof(T));
for (int i = 0; i < x_h; i++) {
for (int j = 0; j < w_w; j++) {
T tmp = static_cast<T>(0);
for (int k = 0; k < x_w; k++) {
tmp += x[i * x_w + k] * w[k * w_w + j];
}
out[i * w_w + j] = tmp + b[j];
}
}
}
template <typename T>
class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::FcParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
CHECK_GE(param.input->dims().size(), 2UL);
CHECK_EQ(param.output->dims().size(), 2UL);
fc_compute_eigen(
param.input->data<T>(), // x
param.input->dims().Slice(0, param.in_num_col_dims).production(),
param.input->dims()
.Slice(param.in_num_col_dims, param.input->dims().size())
.production(),
param.w->data<T>(), // w
param.w->dims()[0], // w_h
param.w->dims()[1], // w_w
param.bias->data<T>(), // b
param.output->mutable_data<T>());
}
virtual ~FcCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
#include "paddle/fluid/lite/kernels/x86/fc_compute.h"
REGISTER_LITE_KERNEL(fc, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::FcCompute<float>, def)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <cstring>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
#include "paddle/fluid/lite/operators/fc_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
void fc_compute_eigen(const T* x, int x_h, int x_w, //
const T* w, int w_h, int w_w, //
const T* b, //
T* out) {
using matrix_t =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
Eigen::Map<const matrix_t> X(x, x_h, x_w);
Eigen::Map<const matrix_t> W(w, w_h, w_w);
Eigen::Map<matrix_t> Out(out, x_h, w_w);
Out = X * W;
if (b) {
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> B(b, w_w);
Out = Out.array().rowwise() + B.transpose().array();
}
}
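// Plain reference implementation of the same GEMM-plus-bias computation.
// Note: unlike fc_compute_eigen above, this version assumes a non-null bias
// pointer b.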
template <typename T>
void fc_compute_naive(const T* x, int x_h, int x_w, //
const T* w, int w_h, int w_w, //
const T* b, //
T* out) {
CHECK_EQ(x_w, w_h);
// out shape: (x_h, w_w)
memset(out, 0, x_h * w_w * sizeof(T));
for (int i = 0; i < x_h; i++) {
for (int j = 0; j < w_w; j++) {
T tmp = static_cast<T>(0);
for (int k = 0; k < x_w; k++) {
tmp += x[i * x_w + k] * w[k * w_w + j];
}
out[i * w_w + j] = tmp + b[j];
}
}
}
template <typename T>
class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::FcParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
CHECK_GE(param.input->dims().size(), 2UL);
CHECK_EQ(param.output->dims().size(), 2UL);
fc_compute_eigen(
param.input->data<T>(), // x
param.input->dims().Slice(0, param.in_num_col_dims).production(),
param.input->dims()
.Slice(param.in_num_col_dims, param.input->dims().size())
.production(),
param.w->data<T>(), // w
param.w->dims()[0], // w_h
param.w->dims()[1], // w_w
param.bias->data<T>(), // b
param.output->mutable_data<T>());
}
virtual ~FcCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/x86/fc_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(fc_x86, retrive_op) {
auto fc =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("fc");
ASSERT_FALSE(fc.empty());
ASSERT_TRUE(fc.front());
}
TEST(fc_x86, init) {
FcCompute<float> fc;
ASSERT_EQ(fc.precision(), PRECISION(kFloat));
ASSERT_EQ(fc.target(), TARGET(kX86));
}
TEST(fc_x86, run_test) {
lite::Tensor x, w, b, out;
constexpr int batch_size = 2;
std::vector<int64_t> x_shape{batch_size, 3};
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> w_shape{3, 4};
w.Resize(lite::DDim(w_shape));
std::vector<int64_t> b_shape{1, 4};
b.Resize(lite::DDim(b_shape));
std::vector<int64_t> out_shape{batch_size, 4};
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto w_data = w.mutable_data<float>();
auto b_data = b.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); i++) {
x_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < w.dims().production(); i++) {
w_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < b.dims().production(); i++) {
b_data[i] = static_cast<float>(i);
}
/* lite::x86::math::fc_compute_eigen(x_data, batch_size, 3, //
w_data, 3, 4, //
b_data, ref_data); */
// FcCompute fc;
FcCompute<float> fc;
operators::FcParam param;
param.in_num_col_dims = 1;
param.input = &x;
param.w = &w;
param.bias = &b;
param.output = &out;
param.in_mat_dims = x.dims();
// std::unique_ptr<KernelContext> ctx(new KernelContext);
// ctx->As<X86Context>();
fc.SetParam(param);
// fc.SetContext(std::move(ctx));
fc.Run();
VLOG(3) << "output vs ref";
for (int i = 0; i < out.dims().production(); i++) {
VLOG(3) << out_data[i];
}
/* for (int i = 0; i < out.dims().product(); ++i) {
EXPECT_NEAR(out_data[i], ref_data[i], 1e-5);
}*/
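// Hedged reference values, computed by hand from the fills above
// (x is a row-major 2x3, w a row-major 3x4, b = {0, 1, 2, 3}):
const float ref[] = {20, 24, 28, 32, 56, 69, 82, 95};
for (int i = 0; i < out.dims().production(); i++) {
  EXPECT_NEAR(out_data[i], ref[i], 1e-5);
}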
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def);
@@ -12,122 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
using Tensor = framework::Tensor;
template <typename T>
class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::MulParam;
void Run() override {
auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::MulParam>();
CHECK(context.x86_device_context());
param.output->template mutable_data<T>();
auto* x = &param.x->raw_tensor();
auto* y = &param.y->raw_tensor();
const Tensor x_matrix = x->dims().size() > 2 ? framework::ReshapeToMatrix(
*x, param.x_num_col_dims)
: *x;
const Tensor y_matrix = y->dims().size() > 2 ? framework::ReshapeToMatrix(
*y, param.y_num_col_dims)
: *y;
auto* z = &param.output->raw_tensor();
auto z_dim = z->dims();
if (z_dim.size() != 2) {
z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}
auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
*context.x86_device_context());
blas.MatMul(x_matrix, y_matrix, z);
if (z_dim.size() != 2) {
z->Resize(z_dim);
}
}
virtual ~MulCompute() = default;
};
template <typename T>
class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
void Run() override {
auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::MulGradParam>();
CHECK(context.x86_device_context());
auto* x = &param.x->raw_tensor();
auto* y = &param.y->raw_tensor();
auto x_matrix = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, param.x_num_col_dims)
: static_cast<const Tensor&>(*x);
auto y_matrix = y->dims().size() > 2
? framework::ReshapeToMatrix(*y, param.y_num_col_dims)
: static_cast<const Tensor&>(*y);
auto* dout = &param.output_grad->raw_tensor();
Tensor dout_mat;
dout_mat.ShareDataWith(*dout);
dout_mat.Resize(
{framework::flatten_to_2d(x->dims(), param.x_num_col_dims)[0],
framework::flatten_to_2d(y->dims(), param.y_num_col_dims)[1]});
auto* dx = &param.x_grad->raw_tensor();
auto* dy = &param.y_grad->raw_tensor();
if (dx != nullptr) {
dx->set_lod(x->lod());
}
if (dy != nullptr) {
dy->set_lod(y->lod());
}
auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
*context.x86_device_context());
if (dx) {
// dx->mutable_data<T>(context.x86_device_context->GetPlace());
param.x_grad->template mutable_data<T>();
Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix(
*dx, param.x_num_col_dims)
: *dx;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix);
}
if (dy) {
// dy->mutable_data<T>(context.x86_device_context->GetPlace());
param.y_grad->template mutable_data<T>();
Tensor dy_matrix = dy->dims().size() > 2 ? framework::ReshapeToMatrix(
*dy, param.y_num_col_dims)
: *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix);
}
}
virtual ~MulGradCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
#include "paddle/fluid/lite/kernels/x86/mul_compute.h"
REGISTER_LITE_KERNEL(mul, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::MulCompute<float>, def)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
using Tensor = framework::Tensor;
template <typename T>
class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::MulParam;
void Run() override {
auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::MulParam>();
CHECK(context.x86_device_context());
param.output->template mutable_data<T>();
auto* x = &param.x->raw_tensor();
auto* y = &param.y->raw_tensor();
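// Tensors of rank > 2 are flattened to matrices: the dimensions in front of
// x_num_col_dims (resp. y_num_col_dims) form the rows, the rest the columns.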
const Tensor x_matrix = x->dims().size() > 2 ? framework::ReshapeToMatrix(
*x, param.x_num_col_dims)
: *x;
const Tensor y_matrix = y->dims().size() > 2 ? framework::ReshapeToMatrix(
*y, param.y_num_col_dims)
: *y;
auto* z = &param.output->raw_tensor();
auto z_dim = z->dims();
if (z_dim.size() != 2) {
z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}
auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
*context.x86_device_context());
blas.MatMul(x_matrix, y_matrix, z);
if (z_dim.size() != 2) {
z->Resize(z_dim);
}
}
virtual ~MulCompute() = default;
};
template <typename T>
class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
void Run() override {
auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::MulGradParam>();
CHECK(context.x86_device_context());
auto* x = &param.x->raw_tensor();
auto* y = &param.y->raw_tensor();
auto x_matrix = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, param.x_num_col_dims)
: static_cast<const Tensor&>(*x);
auto y_matrix = y->dims().size() > 2
? framework::ReshapeToMatrix(*y, param.y_num_col_dims)
: static_cast<const Tensor&>(*y);
auto* dout = &param.output_grad->raw_tensor();
Tensor dout_mat;
dout_mat.ShareDataWith(*dout);
dout_mat.Resize(
{framework::flatten_to_2d(x->dims(), param.x_num_col_dims)[0],
framework::flatten_to_2d(y->dims(), param.y_num_col_dims)[1]});
auto* dx = &param.x_grad->raw_tensor();
auto* dy = &param.y_grad->raw_tensor();
if (dx != nullptr) {
dx->set_lod(x->lod());
}
if (dy != nullptr) {
dy->set_lod(y->lod());
}
auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
*context.x86_device_context());
if (dx) {
// dx->mutable_data<T>(context.x86_device_context->GetPlace());
param.x_grad->template mutable_data<T>();
Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix(
*dx, param.x_num_col_dims)
: *dx;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix);
}
if (dy) {
// dy->mutable_data<T>(context.x86_device_context->GetPlace());
param.y_grad->template mutable_data<T>();
Tensor dy_matrix = dy->dims().size() > 2 ? framework::ReshapeToMatrix(
*dy, param.y_num_col_dims)
: *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix);
}
}
virtual ~MulGradCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/x86/mul_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(mul_x86, retrive_op) {
auto mul =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("mul");
ASSERT_FALSE(mul.empty());
ASSERT_TRUE(mul.front());
}
TEST(mul_x86, init) {
MulCompute<float> mul;
ASSERT_EQ(mul.precision(), PRECISION(kFloat));
ASSERT_EQ(mul.target(), TARGET(kX86));
}
TEST(mul_x86, run_test) {
lite::Tensor x, y, out;
constexpr int batch_size = 1;
std::vector<int64_t> x_shape{batch_size, 3};
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> y_shape{3, 4};
y.Resize(lite::DDim(y_shape));
std::vector<int64_t> out_shape{batch_size, 4};
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto y_data = y.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); i++) {
x_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < y.dims().production(); i++) {
y_data[i] = static_cast<float>(i);
}
// MulCompute mul;
MulCompute<float> mul;
operators::MulParam param;
param.x = &x;
param.y = &y;
param.output = &out;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
mul.SetContext(std::move(ctx));
mul.SetParam(param);
mul.Run();
LOG(INFO) << "output: ";
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
}
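// Hedged reference values, computed by hand: x = {0, 1, 2} (1x3) times the
// row-major 3x4 matrix y with y_data[i] = i gives {20, 23, 26, 29}.
const float ref[] = {20, 23, 26, 29};
for (int i = 0; i < out.dims().production(); i++) {
  EXPECT_NEAR(out_data[i], ref[i], 1e-5);
}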
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def);
@@ -12,69 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <Eigen/Core>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class PoolCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::PoolParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
if (param.global_pooling) {
for (size_t i = 0; i < param.ksize.size(); ++i) {
param.paddings[i] = 0;
param.ksize[i] = static_cast<int>(param.x->dims()[i + 2]);
}
}
switch (param.ksize.size()) {
case 2: {
if (param.pooling_type == "max") {
paddle::operators::math::Pool2dFunctor<
platform::CPUDeviceContext, paddle::operators::math::MaxPool<T>,
T>
pool2d_forward;
paddle::operators::math::MaxPool<T> pool_process;
pool2d_forward(platform::CPUDeviceContext(), param.x->raw_tensor(),
param.ksize, param.strides, param.paddings,
pool_process, true, false,
&(param.output->raw_tensor()));
} else if (param.pooling_type == "avg") {
paddle::operators::math::Pool2dFunctor<
platform::CPUDeviceContext, paddle::operators::math::AvgPool<T>,
T>
pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(platform::CPUDeviceContext(), param.x->raw_tensor(),
param.ksize, param.strides, param.paddings,
pool_process, param.exclusive, param.adaptive,
&(param.output->raw_tensor()));
}
} break;
case 3: {
} break;
}
}
virtual ~PoolCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
#include "paddle/fluid/lite/kernels/x86/pool_compute.h"
REGISTER_LITE_KERNEL(pool2d, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::PoolCompute<float>, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("x", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class PoolCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::PoolParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
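// Global pooling covers the whole spatial extent: clear the paddings and
// grow the kernel to the input's full H and W.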
if (param.global_pooling) {
for (size_t i = 0; i < param.ksize.size(); ++i) {
param.paddings[i] = 0;
param.ksize[i] = static_cast<int>(param.x->dims()[i + 2]);
}
}
switch (param.ksize.size()) {
case 2: {
if (param.pooling_type == "max") {
paddle::operators::math::Pool2dFunctor<
platform::CPUDeviceContext, paddle::operators::math::MaxPool<T>,
T>
pool2d_forward;
paddle::operators::math::MaxPool<T> pool_process;
pool2d_forward(platform::CPUDeviceContext(), param.x->raw_tensor(),
param.ksize, param.strides, param.paddings,
pool_process, true, false,
&(param.output->raw_tensor()));
} else if (param.pooling_type == "avg") {
paddle::operators::math::Pool2dFunctor<
platform::CPUDeviceContext, paddle::operators::math::AvgPool<T>,
T>
pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(platform::CPUDeviceContext(), param.x->raw_tensor(),
param.ksize, param.strides, param.paddings,
pool_process, param.exclusive, param.adaptive,
&(param.output->raw_tensor()));
}
} break;
case 3: {
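// 3-D pooling is not implemented yet; intentionally a no-op.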
} break;
}
}
virtual ~PoolCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/x86/pool_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(pool_x86, retrive_op) {
auto pool2d =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"pool2d");
ASSERT_FALSE(pool2d.empty());
ASSERT_TRUE(pool2d.front());
}
TEST(pool2d_x86, init) {
PoolCompute<float> pool2d;
ASSERT_EQ(pool2d.precision(), PRECISION(kFloat));
ASSERT_EQ(pool2d.target(), TARGET(kX86));
}
TEST(pool2d_x86, run_test) {
lite::Tensor x, out;
constexpr int batch_size = 1;
std::vector<int64_t> x_shape{batch_size, 3, 4, 4};
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape{batch_size, 3, 2, 2};
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); i++) {
x_data[i] = static_cast<float>(i);
}
PoolCompute<float> pool2d;
operators::PoolParam param;
param.x = &x;
param.output = &out;
param.strides = {2, 2};
param.paddings = {0, 0};
param.ksize = {2, 2};
param.pooling_type = "max";
pool2d.SetParam(param);
pool2d.Run();
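// Hand-computed reference: each channel holds 16 consecutive values, so
// 2x2/stride-2 max pooling yields {5, 7, 13, 15} in channel 0, offset by
// 16 per channel: {5, 7, 13, 15, 21, 23, 29, 31, 37, 39, 45, 47}.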
LOG(INFO) << "output: ";
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(pool2d, kX86, kFloat, kNCHW, def);
......@@ -12,42 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <Eigen/Core>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
#include "paddle/fluid/lite/operators/relu_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ReluParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto n = param.input->dims().production();
const float* input = param.input->data<float>();
float* output = param.output->mutable_data<float>();
for (int i = 0; i < n; i++) {
output[i] = std::max(0.f, input[i]);
}
}
virtual ~ReluCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
#include "paddle/fluid/lite/kernels/x86/relu_compute.h"
REGISTER_LITE_KERNEL(relu, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::ReluCompute<float>, def)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
#include "paddle/fluid/lite/operators/relu_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ReluParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto n = param.input->dims().production();
const float* input = param.input->data<float>();
float* output = param.output->mutable_data<float>();
for (int i = 0; i < n; i++) {
output[i] = std::max(0.f, input[i]);
}
}
virtual ~ReluCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/x86/relu_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(relu_x86, retrive_op) {
auto relu =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("relu");
ASSERT_FALSE(relu.empty());
ASSERT_TRUE(relu.front());
}
TEST(relu_x86, init) {
ReluCompute<float> relu;
ASSERT_EQ(relu.precision(), PRECISION(kFloat));
ASSERT_EQ(relu.target(), TARGET(kX86));
}
TEST(relu_x86, run_test) {
lite::Tensor x, out;
constexpr int batch_size = 1;
std::vector<int64_t> x_shape{batch_size, 3, 2, 2};
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape{batch_size, 3, 2, 2};
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); i++) {
int sign = i % 2 == 0 ? 1 : -1;
x_data[i] = static_cast<float>(i * sign);
}
// ReluCompute relu;
ReluCompute<float> relu;
operators::ReluParam param;
param.input = &x;
param.output = &out;
relu.SetParam(param);
relu.Run();
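// Hand-computed reference: odd indices hold negative values and clamp to
// zero, so the output is {0, 0, 2, 0, 4, 0, 6, 0, 8, 0, 10, 0}.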
LOG(INFO) << "output: ";
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(relu, kX86, kFloat, kNCHW, def);
......@@ -12,48 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <Eigen/Core>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
#include "paddle/fluid/lite/operators/relu_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
void scale_compute(const T* x, T* out, int size, float scale, float bias,
bool bias_before) {
if (bias_before) bias *= scale;
for (int i = 0; i < size; i++) {
out[i] = x[i] * scale + bias;
}
}
template <typename T>
class ScaleCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ScaleParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
scale_compute(param.x->data<T>(), param.output->mutable_data<T>(),
param.x->dims().production(), param.scale, param.bias,
param.bias_after_scale);
}
virtual ~ScaleCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
#include "paddle/fluid/lite/kernels/x86/scale_compute.h"
REGISTER_LITE_KERNEL(scale, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::ScaleCompute<float>, def)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
#include "paddle/fluid/lite/operators/relu_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
void scale_compute(const T* x, T* out, int size, float scale, float bias,
bool bias_before) {
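// "bias_before" marks that the original op adds the bias before scaling:
// (x + bias) * scale == x * scale + (bias * scale), so fold the scale
// into the bias and keep a single fused multiply-add below.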
if (bias_before) bias *= scale;
for (int i = 0; i < size; i++) {
out[i] = x[i] * scale + bias;
}
}
template <typename T>
class ScaleCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ScaleParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
scale_compute(param.x->data<T>(), param.output->mutable_data<T>(),
param.x->dims().production(), param.scale, param.bias,
param.bias_after_scale);
}
virtual ~ScaleCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/x86/scale_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(scale_x86, retrive_op) {
auto scale =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("scale");
ASSERT_FALSE(scale.empty());
ASSERT_TRUE(scale.front());
}
TEST(scale_x86, init) {
ScaleCompute<float> scale;
ASSERT_EQ(scale.precision(), PRECISION(kFloat));
ASSERT_EQ(scale.target(), TARGET(kX86));
}
TEST(scale_x86, run_test) {
lite::Tensor x, out;
constexpr int batch_size = 1;
std::vector<int64_t> x_shape{batch_size, 3, 2, 2};
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape{batch_size, 3, 2, 2};
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); i++) {
x_data[i] = static_cast<float>(i);
}
// ScaleCompute scale;
ScaleCompute<float> scale;
operators::ScaleParam param;
param.x = &x;
param.scale = 0.5;
param.bias = 0;
param.output = &out;
scale.SetParam(param);
scale.Run();
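// Hand-computed reference: out[i] == 0.5f * i, i.e. {0, 0.5, 1, ..., 5.5}.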
LOG(INFO) << "output: ";
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def);
......@@ -12,76 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
static inline int CanonicalAxis(const int axis, const int rank) {
if (axis < 0) {
return axis + rank;
}
return axis;
}
static inline int SizeToAxis(const int axis, lite::DDim dims) {
int size = 1;
for (int i = 0; i < axis; i++) {
size *= dims[i];
}
return size;
}
static inline int SizeFromAxis(const int axis, lite::DDim dims) {
int size = 1;
for (int i = axis; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
template <typename T>
class SoftmaxCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::SoftmaxParam;
void Run() override {
auto& param = *param_.get_mutable<operators::SoftmaxParam>();
// auto& context = context_->As<X86Context>();
CHECK(param.output);
CHECK(param.x);
const int rank = param.x->dims().size();
const int axis = CanonicalAxis(param.axis, rank);
int axis_dim = param.x->dims()[axis];
const int n = SizeToAxis(axis, param.x->dims());
const int d = SizeFromAxis(axis, param.x->dims());
std::vector<int64_t> shape{n, d};
lite::Tensor input_2d, out_2d;
input_2d.ShareDataWith(*param.x);
input_2d.Resize(lite::DDim(shape));
out_2d.ShareDataWith(*param.output);
out_2d.Resize(lite::DDim(shape));
paddle::operators::math::SoftmaxFunctor<platform::CPUDeviceContext, T,
true>()(
platform::CPUDeviceContext(), axis_dim, &input_2d.raw_tensor(),
&out_2d.raw_tensor());
}
virtual ~SoftmaxCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
#include "paddle/fluid/lite/kernels/x86/softmax_compute.h"
REGISTER_LITE_KERNEL(softmax, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::SoftmaxCompute<float>, def)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/operators/math/softmax.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
static inline int CanonicalAxis(const int axis, const int rank) {
if (axis < 0) {
return axis + rank;
}
return axis;
}
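// Product of the dimensions before `axis`: the number of independent
// softmax rows after flattening.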
static inline int SizeToAxis(const int axis, lite::DDim dims) {
int size = 1;
for (int i = 0; i < axis; i++) {
size *= dims[i];
}
return size;
}
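// Product of the dimensions from `axis` onward: the width of each
// flattened row.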
static inline int SizeFromAxis(const int axis, lite::DDim dims) {
int size = 1;
for (size_t i = axis; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
template <typename T>
class SoftmaxCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::SoftmaxParam;
void Run() override {
auto& param = *param_.get_mutable<operators::SoftmaxParam>();
// auto& context = context_->As<X86Context>();
CHECK(param.output);
CHECK(param.x);
const int rank = param.x->dims().size();
const int axis = CanonicalAxis(param.axis, rank);
int axis_dim = param.x->dims()[axis];
const int n = SizeToAxis(axis, param.x->dims());
const int d = SizeFromAxis(axis, param.x->dims());
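// View input and output as n x d matrices sharing the original buffers;
// the functor then applies softmax over axis_dim-sized slices of each row.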
std::vector<int64_t> shape{n, d};
lite::Tensor input_2d, out_2d;
input_2d.ShareDataWith(*param.x);
input_2d.Resize(lite::DDim(shape));
out_2d.ShareDataWith(*param.output);
out_2d.Resize(lite::DDim(shape));
paddle::operators::math::SoftmaxFunctor<platform::CPUDeviceContext, T,
true>()(
platform::CPUDeviceContext(), axis_dim, &input_2d.raw_tensor(),
&out_2d.raw_tensor());
}
virtual ~SoftmaxCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/x86/softmax_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(softmax_x86, retrive_op) {
auto softmax =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"softmax");
ASSERT_FALSE(softmax.empty());
ASSERT_TRUE(softmax.front());
}
TEST(softmax_x86, init) {
SoftmaxCompute<float> softmax;
ASSERT_EQ(softmax.precision(), PRECISION(kFloat));
ASSERT_EQ(softmax.target(), TARGET(kX86));
}
TEST(softmax_x86, run_test) {
lite::Tensor x, out;
constexpr int batch_size = 1;
std::vector<int64_t> x_shape{batch_size, 3, 3, 3};
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape{batch_size, 3, 3, 3};
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); i++) {
x_data[i] = static_cast<float>(i);
}
SoftmaxCompute<float> softmax;
operators::SoftmaxParam param;
param.x = &x;
param.output = &out;
softmax.SetParam(param);
softmax.Run();
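// Sanity check: within every slice along the softmax axis the outputs
// should be positive and sum to 1.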
LOG(INFO) << "output: ";
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def);
......@@ -52,7 +52,7 @@ class DropoutOpLite : public OpLite {
param_.mask = GetMutableVar<lite::Tensor>(scope, Mask);
param_.dropout_prob = op_desc.GetAttr<float>("dropout_prob");
if (op_desc.HasAttr("axis")) {
if (op_desc.HasAttr("is_test")) {
param_.is_test = op_desc.GetAttr<bool>("is_test");
}
param_.fix_seed = op_desc.GetAttr<bool>("fix_seed");
......