未验证 提交 aa507f9b 编写于 作者: L liu zhengxi 提交者: GitHub

Enable pool2d, dropout, transpose and transpose2 op on x86 (#2226)

* enable pool2d op on x86 and add its unit tests, test=develop

* enable dropout op and add its unit tests, test=develop

* add tranpose, transpose2 op and add their unit tests, test=develop
上级 76e74ef1
...@@ -30,7 +30,7 @@ template <typename PoolProcess, typename T> ...@@ -30,7 +30,7 @@ template <typename PoolProcess, typename T>
class Pool2dFunctor<lite::TargetType::kX86, PoolProcess, T> { class Pool2dFunctor<lite::TargetType::kX86, PoolProcess, T> {
public: public:
void operator()(const lite::X86Context& context, void operator()(const lite::X86Context& context,
const lite::Tensor& input, const lite::Tensor* input,
const std::vector<int>& ksize, const std::vector<int>& ksize,
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
...@@ -38,9 +38,9 @@ class Pool2dFunctor<lite::TargetType::kX86, PoolProcess, T> { ...@@ -38,9 +38,9 @@ class Pool2dFunctor<lite::TargetType::kX86, PoolProcess, T> {
bool exclusive, bool exclusive,
bool adaptive, bool adaptive,
lite::Tensor* output) { lite::Tensor* output) {
const int batch_size = input.dims()[0]; const int batch_size = input->dims()[0];
const int input_height = input.dims()[2]; const int input_height = input->dims()[2];
const int input_width = input.dims()[3]; const int input_width = input->dims()[3];
const int output_channels = output->dims()[1]; const int output_channels = output->dims()[1];
const int output_height = output->dims()[2]; const int output_height = output->dims()[2];
const int output_width = output->dims()[3]; const int output_width = output->dims()[3];
...@@ -54,7 +54,7 @@ class Pool2dFunctor<lite::TargetType::kX86, PoolProcess, T> { ...@@ -54,7 +54,7 @@ class Pool2dFunctor<lite::TargetType::kX86, PoolProcess, T> {
const int input_stride = input_height * input_width; const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width; const int output_stride = output_height * output_width;
const T* input_data = input.data<T>(); const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(lite::TargetType::kX86); T* output_data = output->mutable_data<T>(lite::TargetType::kX86);
int hstart, hend; int hstart, hend;
......
...@@ -94,24 +94,12 @@ HOSTDEVICE inline int AdaptEndIndex(int ph, int input_size, int output_size) { ...@@ -94,24 +94,12 @@ HOSTDEVICE inline int AdaptEndIndex(int ph, int input_size, int output_size) {
* This is different from average pooling. So we rewrite the max_pool_grad: * This is different from average pooling. So we rewrite the max_pool_grad:
* MaxPool2dGradFunctor, MaxPool3dGradFunctor. * MaxPool2dGradFunctor, MaxPool3dGradFunctor.
*/ */
//#ifdef PADDLE_WITH_CUDA
// template <typename PoolProcess, typename T>
// class Pool2dDirectCUDAFunctor {
// public:
// void operator()(const T* input, const std::vector<int>& input_shape,
// const std::vector<int>& output_shape,
// const std::vector<int>& ksize,
// const std::vector<int>& strides,
// const std::vector<int>& paddings, PoolProcess pool_compute,
// bool exclusive, T* output, cudaStream_t stream);
//};
//#endif
template <lite::TargetType Target, typename PoolProcess, typename T> template <lite::TargetType Target, typename PoolProcess, typename T>
class Pool2dFunctor { class Pool2dFunctor {
public: public:
void operator()(const lite::Context<Target>& context, void operator()(const lite::Context<Target>& context,
const lite::Tensor& input, const lite::Tensor* input,
const std::vector<int>& ksize, const std::vector<int>& ksize,
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
......
...@@ -4,7 +4,6 @@ add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${li ...@@ -4,7 +4,6 @@ add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${li
# lite_cc_library(sgd_compute_x86 SRCS sgd_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(sgd_compute_x86 SRCS sgd_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(fc_compute_x86 SRCS fc_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(fc_compute_x86 SRCS fc_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(relu_compute_x86 SRCS relu_compute.cc DEPS ${lite_kernel_deps})
add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc DEPS ${lite_kernel_deps}) add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc DEPS ${lite_kernel_deps})
add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc DEPS ${lite_kernel_deps}) add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc DEPS ${lite_kernel_deps})
add_kernel(squeeze_compute_x86 X86 basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps}) add_kernel(squeeze_compute_x86 X86 basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps})
...@@ -15,7 +14,10 @@ add_kernel(conv_compute_x86 X86 basic SRCS conv_compute.cc DEPS ${lite_kernel_de ...@@ -15,7 +14,10 @@ add_kernel(conv_compute_x86 X86 basic SRCS conv_compute.cc DEPS ${lite_kernel_de
# lite_cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) # lite_cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
# lite_cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} ) # lite_cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} )
# lite_cc_library(conv_compute_x86 SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col) # lite_cc_library(conv_compute_x86 SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col)
# lite_cc_library(pool_compute_x86 SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling) add_kernel(pool_compute_x86 X86 basic SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling)
add_kernel(dropout_compute_x86 X86 basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps})
add_kernel(transpose_compute_x86 X86 basic SRCS transpose_compute.cc DEPS ${lite_kernel_deps} math_function)
# add_kernel(fc_compute_x86 X86 basic SRCS fc_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(batch_norm_compute_x86 SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(batch_norm_compute_x86 SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(uniform_random_compute_x86 SRCS uniform_random_compute.cc DEPS ${lite_kernel_deps} ) # lite_cc_library(uniform_random_compute_x86 SRCS uniform_random_compute.cc DEPS ${lite_kernel_deps} )
add_kernel(gru_compute_x86 X86 basic SRCS gru_compute.cc DEPS ${lite_kernel_deps} blas math_function sequence2batch gru_compute) add_kernel(gru_compute_x86 X86 basic SRCS gru_compute.cc DEPS ${lite_kernel_deps} blas math_function sequence2batch gru_compute)
...@@ -24,7 +26,6 @@ add_kernel(sequence_expand_as_compute_x86 X86 basic SRCS sequence_expand_as_comp ...@@ -24,7 +26,6 @@ add_kernel(sequence_expand_as_compute_x86 X86 basic SRCS sequence_expand_as_comp
# lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86) # lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86)
# lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86) # lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
# lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86)
# lite_cc_test(test_scale_compute_x86 SRCS scale_compute_test.cc DEPS scale_compute_x86) # lite_cc_test(test_scale_compute_x86 SRCS scale_compute_test.cc DEPS scale_compute_x86)
# lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86) # lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
# lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86) # lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86)
...@@ -57,3 +58,7 @@ lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS activation_com ...@@ -57,3 +58,7 @@ lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS activation_com
lite_cc_test(test_sequence_expand_as_compute_x86 SRCS sequence_expand_as_compute_test.cc DEPS sequence_expand_as_compute_x86) lite_cc_test(test_sequence_expand_as_compute_x86 SRCS sequence_expand_as_compute_test.cc DEPS sequence_expand_as_compute_x86)
lite_cc_test(test_gru_compute_x86 SRCS gru_compute_test.cc DEPS gru_compute_x86) lite_cc_test(test_gru_compute_x86 SRCS gru_compute_test.cc DEPS gru_compute_x86)
lite_cc_test(test_matmul_compute_x86 SRCS matmul_compute_test.cc DEPS matmul_compute_x86) lite_cc_test(test_matmul_compute_x86 SRCS matmul_compute_test.cc DEPS matmul_compute_x86)
lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86)
lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS transpose_compute_x86)
...@@ -13,12 +13,14 @@ ...@@ -13,12 +13,14 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <Eigen/Core>
#include <random> #include <random>
#include <string> #include <string>
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "paddle/fluid/framework/eigen.h" #include "lite/core/types.h"
#include "paddle/fluid/framework/operator.h" #include "lite/fluid/eigen.h"
#include "lite/operators/dropout_op.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -28,7 +30,7 @@ namespace x86 { ...@@ -28,7 +30,7 @@ namespace x86 {
template <typename T, template <typename T,
int MajorType = Eigen::RowMajor, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = lite::fluid::EigenMatrix<T, MajorType, IndexType>;
template <typename T> template <typename T>
class DropoutCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { class DropoutCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
...@@ -37,16 +39,16 @@ class DropoutCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -37,16 +39,16 @@ class DropoutCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override { void Run() override {
auto& param = *param_.get_mutable<operators::DropoutParam>(); auto& param = *param_.get_mutable<operators::DropoutParam>();
const auto* x_data = param.x->data<T>(); const auto* x_data = param.x->data<T>();
auto* out_data = param.output->template mutable_data<T>(); auto* out_data = param.output->mutable_data<T>();
if (!param.is_test) { if (!param.is_test) {
auto* mask_data = param.mask->template mutable_data<T>(); auto* mask_data = param.mask->mutable_data<T>();
std::random_device rnd; std::random_device rnd;
std::minstd_rand engine; std::minstd_rand engine;
int seed = param.fix_seed ? param.seed : rnd(); int seed = param.fix_seed ? param.seed : rnd();
engine.seed(seed); engine.seed(seed);
std::uniform_real_distribution<float> dist(0, 1); std::uniform_real_distribution<float> dist(0, 1);
size_t size = framework::product(param.mask->dims().data()); size_t size = param.mask->dims().production();
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
if (dist(engine) < param.dropout_prob) { if (dist(engine) < param.dropout_prob) {
mask_data[i] = 0; mask_data[i] = 0;
...@@ -62,13 +64,13 @@ class DropoutCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -62,13 +64,13 @@ class DropoutCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
} }
} }
} else { } else {
auto X = EigenMatrix<T>::Reshape(param.x->raw_tensor(), 1); auto X = EigenMatrix<T>::Reshape(*param.x, 1);
auto Y = EigenMatrix<T>::Reshape(param.output->raw_tensor(), 1); auto Y = EigenMatrix<T>::Reshape(*param.output, 1);
auto& place = *platform::CPUDeviceContext().eigen_device();
if (param.dropout_implementation == "upscale_in_train") { if (param.dropout_implementation == "upscale_in_train") {
Y.device(place) = X; Y.device(lite::fluid::EigenDeviceType<lite::TargetType::kX86>()) = X;
} else { } else {
Y.device(place) = X * static_cast<T>(1.0f - param.dropout_prob); Y.device(lite::fluid::EigenDeviceType<lite::TargetType::kX86>()) =
X * static_cast<T>(1.0f - param.dropout_prob);
} }
} }
} }
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include "lite/kernels/x86/dropout_compute.h" #include "lite/kernels/x86/dropout_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory>
#include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
...@@ -60,7 +62,9 @@ TEST(dropout_x86, run_test) { ...@@ -60,7 +62,9 @@ TEST(dropout_x86, run_test) {
param.is_test = true; param.is_test = true;
param.fix_seed = true; param.fix_seed = true;
param.output = &out; param.output = &out;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
dropout.SetContext(std::move(ctx));
dropout.SetParam(param); dropout.SetParam(param);
dropout.Run(); dropout.Run();
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
#pragma once #pragma once
#include <Eigen/Core> #include <Eigen/Core>
#include "lite/backends/x86/math/math_function.h"
#include "lite/backends/x86/math/pooling.h"
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/core/types.h" #include "lite/core/types.h"
#include "paddle/fluid/framework/eigen.h" #include "lite/fluid/eigen.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -31,6 +31,7 @@ class PoolCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -31,6 +31,7 @@ class PoolCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public: public:
using param_t = operators::PoolParam; using param_t = operators::PoolParam;
void Run() override { void Run() override {
auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<param_t>(); auto& param = *param_.get_mutable<param_t>();
if (param.global_pooling) { if (param.global_pooling) {
for (size_t i = 0; i < param.ksize.size(); ++i) { for (size_t i = 0; i < param.ksize.size(); ++i) {
...@@ -41,37 +42,37 @@ class PoolCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -41,37 +42,37 @@ class PoolCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
switch (param.ksize.size()) { switch (param.ksize.size()) {
case 2: { case 2: {
if (param.pooling_type == "max") { if (param.pooling_type == "max") {
paddle::operators::math::Pool2dFunctor< paddle::lite::x86::math::Pool2dFunctor<
platform::CPUDeviceContext, lite::TargetType::kX86,
paddle::operators::math::MaxPool<T>, paddle::lite::x86::math::MaxPool<T>,
T> T>
pool2d_forward; pool2d_forward;
paddle::operators::math::MaxPool<T> pool_process; paddle::lite::x86::math::MaxPool<T> pool_process;
pool2d_forward(platform::CPUDeviceContext(), pool2d_forward(context,
param.x->raw_tensor(), param.x,
param.ksize, param.ksize,
param.strides, param.strides,
param.paddings, param.paddings,
pool_process, pool_process,
true, true,
false, false,
&(param.output->raw_tensor())); param.output);
} else if (param.pooling_type == "avg") { } else if (param.pooling_type == "avg") {
paddle::operators::math::Pool2dFunctor< paddle::lite::x86::math::Pool2dFunctor<
platform::CPUDeviceContext, lite::TargetType::kX86,
paddle::operators::math::AvgPool<T>, paddle::lite::x86::math::AvgPool<T>,
T> T>
pool2d_forward; pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process; paddle::lite::x86::math::AvgPool<T> pool_process;
pool2d_forward(platform::CPUDeviceContext(), pool2d_forward(context,
param.x->raw_tensor(), param.x,
param.ksize, param.ksize,
param.strides, param.strides,
param.paddings, param.paddings,
pool_process, pool_process,
param.exclusive, param.exclusive,
param.adaptive, param.adaptive,
&(param.output->raw_tensor())); param.output);
} }
} break; } break;
case 3: { case 3: {
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include "lite/kernels/x86/pool_compute.h" #include "lite/kernels/x86/pool_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory>
#include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
...@@ -61,7 +63,9 @@ TEST(pool2d_x86, run_test) { ...@@ -61,7 +63,9 @@ TEST(pool2d_x86, run_test) {
param.paddings = {0, 0}; param.paddings = {0, 0};
param.ksize = {2, 2}; param.ksize = {2, 2};
param.pooling_type = "max"; param.pooling_type = "max";
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
pool2d.SetContext(std::move(ctx));
pool2d.SetParam(param); pool2d.SetParam(param);
pool2d.Run(); pool2d.Run();
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/transpose_compute.h"
REGISTER_LITE_KERNEL(transpose,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::TransposeCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
REGISTER_LITE_KERNEL(transpose2,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::Transpose2Compute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <vector>
#include "lite/backends/x86/math/math_function.h"
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/transpose_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <lite::TargetType Target, typename T>
inline void TransCompute(const int dim,
const lite::Context<Target>& context,
const lite::Tensor& in,
lite::Tensor* out,
const std::vector<int>& axis) {
switch (dim) {
case 1:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 1> trans1;
trans1(context, in, out, axis);
break;
case 2:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 2> trans2;
trans2(context, in, out, axis);
break;
case 3:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 3> trans3;
trans3(context, in, out, axis);
break;
case 4:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 4> trans4;
trans4(context, in, out, axis);
break;
case 5:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 5> trans5;
trans5(context, in, out, axis);
break;
case 6:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 6> trans6;
trans6(context, in, out, axis);
break;
default:
PADDLE_THROW("Tensors with rank at most 6 are supported");
}
}
template <typename T>
class TransposeCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::TransposeParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto* x = param.x;
auto* out = param.output;
out->mutable_data<T>();
int ndims = param.axis.size();
auto& context = ctx_->As<X86Context>();
TransCompute<lite::TargetType::kX86, T>(
ndims, context, *x, out, param.axis);
}
virtual ~TransposeCompute() = default;
};
template <typename T>
class Transpose2Compute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::TransposeParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto* x = param.x;
auto* out = param.output;
out->mutable_data<T>();
int ndims = param.axis.size();
auto& context = ctx_->As<X86Context>();
TransCompute<lite::TargetType::kX86, T>(
ndims, context, *x, out, param.axis);
}
virtual ~Transpose2Compute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/transpose_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// transpose
TEST(transpose_x86, retrive_op) {
auto transpose =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"transpose");
ASSERT_FALSE(transpose.empty());
ASSERT_TRUE(transpose.front());
}
TEST(transpose_x86, init) {
lite::kernels::x86::TransposeCompute<float> transpose;
ASSERT_EQ(transpose.precision(), PRECISION(kFloat));
ASSERT_EQ(transpose.target(), TARGET(kX86));
}
TEST(transpose_x86, run_test) {
lite::Tensor x;
lite::Tensor out;
std::vector<int64_t> x_shape({3, 4, 5});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({3, 5, 4});
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
// TransposeCompute transpose;
TransposeCompute<float> transpose;
operators::TransposeParam param;
param.x = &x;
param.output = &out;
std::vector<int> axis({0, 2, 1});
param.axis = axis;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
transpose.SetContext(std::move(ctx));
transpose.SetParam(param);
transpose.Run();
for (int j = 0; j < out.dims().production(); ++j) {
// EXPECT_NEAR(out_data[j], x_data[j], 1e-5);
LOG(INFO) << out_data[j];
}
}
// transpose2
TEST(transpose2_x86, retrive_op) {
auto transpose2 =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"transpose2");
ASSERT_FALSE(transpose2.empty());
ASSERT_TRUE(transpose2.front());
}
TEST(transpose2_x86, init) {
lite::kernels::x86::Transpose2Compute<float> transpose2;
ASSERT_EQ(transpose2.precision(), PRECISION(kFloat));
ASSERT_EQ(transpose2.target(), TARGET(kX86));
}
TEST(transpose2_x86, run_test) {
lite::Tensor x;
lite::Tensor xshape;
lite::Tensor out;
std::vector<int64_t> x_shape({3, 4, 5});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({3, 5, 4});
out.Resize(lite::DDim(out_shape));
std::vector<int64_t> xshape_shape({3, 4, 5});
xshape.Resize(lite::DDim(xshape_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
auto xshape_data = xshape.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
xshape_data[i] = static_cast<float>(i);
}
// Transpose2Compute transpose2;
Transpose2Compute<float> transpose2;
operators::TransposeParam param;
param.x = &x;
param.output = &out;
param.xshape = &xshape;
std::vector<int> axis({0, 2, 1});
param.axis = axis;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
transpose2.SetContext(std::move(ctx));
transpose2.SetParam(param);
transpose2.Run();
for (int j = 0; j < out.dims().production(); ++j) {
LOG(INFO) << out_data[j];
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(transpose, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(transpose2, kX86, kFloat, kNCHW, def);
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/operators/dropout_op.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/core/op_lite.h" #include "lite/core/op_lite.h"
...@@ -20,16 +21,12 @@ namespace paddle { ...@@ -20,16 +21,12 @@ namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
class DropoutOpLite : public OpLite { bool DropoutOp::CheckShape() const {
public:
explicit DropoutOpLite(const std::string& type) : OpLite(type) {}
bool CheckShape() const override {
CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.x);
return true; return true;
} }
bool InferShape() const override { bool DropoutOp::InferShape() const {
const auto x_dims = param_.x->dims(); const auto x_dims = param_.x->dims();
param_.output->Resize(x_dims); param_.output->Resize(x_dims);
if (param_.is_test == false) { if (param_.is_test == false) {
...@@ -38,11 +35,10 @@ class DropoutOpLite : public OpLite { ...@@ -38,11 +35,10 @@ class DropoutOpLite : public OpLite {
// share LoD // share LoD
// param_.output->set_lod(param_.input->lod()); // param_.output->set_lod(param_.input->lod());
return true; return true;
} }
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } // TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one. bool DropoutOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
auto input = op_desc.Input("X").front(); auto input = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front(); auto out = op_desc.Output("Out").front();
auto Mask = op_desc.Output("Mask").front(); auto Mask = op_desc.Output("Mask").front();
...@@ -63,16 +59,10 @@ class DropoutOpLite : public OpLite { ...@@ -63,16 +59,10 @@ class DropoutOpLite : public OpLite {
param_.dropout_implementation = param_.dropout_implementation =
op_desc.GetAttr<std::string>("dropout_implementation"); op_desc.GetAttr<std::string>("dropout_implementation");
return true; return true;
} }
std::string DebugString() const override { return "dropout"; }
private:
mutable DropoutParam param_;
};
} // namespace operators } // namespace operators
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_OP(dropout, paddle::lite::operators::DropoutOpLite); REGISTER_LITE_OP(dropout, paddle::lite::operators::DropoutOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
class DropoutOp : public OpLite {
public:
explicit DropoutOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override;
std::string DebugString() const override { return "dropout"; }
private:
mutable DropoutParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
...@@ -87,6 +87,10 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { ...@@ -87,6 +87,10 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>(); param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.in_num_col_dims = op_desc.GetAttr<int>("in_num_col_dims"); param_.in_num_col_dims = op_desc.GetAttr<int>("in_num_col_dims");
if (op_desc.HasAttr("activation_type")) {
param_.activation_type = op_desc.GetAttr<std::string>("activation_type");
}
// For Int8 // For Int8
if (op_desc.HasAttr("enable_int8")) { if (op_desc.HasAttr("enable_int8")) {
param_.enable_int8 = op_desc.GetAttr<bool>("enable_int8"); param_.enable_int8 = op_desc.GetAttr<bool>("enable_int8");
......
...@@ -83,6 +83,7 @@ struct FcParam { ...@@ -83,6 +83,7 @@ struct FcParam {
lite::Tensor* output{nullptr}; lite::Tensor* output{nullptr};
lite::DDim in_mat_dims; lite::DDim in_mat_dims;
int in_num_col_dims{1}; int in_num_col_dims{1};
std::string activation_type{""};
// for int8 // for int8
WITH_INT8_CONFIG WITH_INT8_CONFIG
}; };
...@@ -323,6 +324,8 @@ struct SplitParam { ...@@ -323,6 +324,8 @@ struct SplitParam {
struct TransposeParam { struct TransposeParam {
const lite::Tensor* x{}; const lite::Tensor* x{};
lite::Tensor* output{}; lite::Tensor* output{};
lite::Tensor* xshape{};
std::vector<int> axis; std::vector<int> axis;
bool use_mkldnn{false}; bool use_mkldnn{false};
std::string data_format{"AnyLayout"}; std::string data_format{"AnyLayout"};
......
...@@ -154,6 +154,10 @@ bool Transpose2Op::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { ...@@ -154,6 +154,10 @@ bool Transpose2Op::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
if (op_desc.HasAttr("data_format")) { if (op_desc.HasAttr("data_format")) {
param_.data_format = op_desc.GetAttr<std::string>("data_format"); param_.data_format = op_desc.GetAttr<std::string>("data_format");
} }
if (op_desc.HasOutput("XShape")) {
auto xshape_var = scope->FindVar(op_desc.Output("XShape").front());
param_.xshape = xshape_var->GetMutable<lite::Tensor>();
}
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册