Unverified commit aa507f9b authored by liu zhengxi, committed by GitHub

Enable pool2d, dropout, transpose and transpose2 op on x86 (#2226)

* enable pool2d op on x86 and add its unit tests, test=develop

* enable dropout op and add its unit tests, test=develop

* add transpose, transpose2 op and add their unit tests, test=develop
Parent 76e74ef1
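
For orientation (not part of the diff): each newly enabled kernel is exercised through Lite's usual unit-test pattern, condensed in the sketch below with the shapes taken from the transpose test added in this commit:

    // Condensed from transpose_compute_test.cc in this commit.
    lite::Tensor x, out;
    x.Resize(lite::DDim(std::vector<int64_t>({3, 4, 5})));
    out.Resize(lite::DDim(std::vector<int64_t>({3, 5, 4})));
    auto* x_data = x.mutable_data<float>();
    for (int64_t i = 0; i < x.dims().production(); ++i) {
      x_data[i] = static_cast<float>(i);
    }
    TransposeCompute<float> kernel;           // kernel under test
    operators::TransposeParam param;
    param.x = &x;
    param.output = &out;
    param.axis = {0, 2, 1};
    std::unique_ptr<KernelContext> ctx(new KernelContext);
    ctx->As<X86Context>();                    // bind an x86 context
    kernel.SetContext(std::move(ctx));
    kernel.SetParam(param);
    kernel.Run();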
@@ -30,7 +30,7 @@ template <typename PoolProcess, typename T>
class Pool2dFunctor<lite::TargetType::kX86, PoolProcess, T> {
public:
void operator()(const lite::X86Context& context,
const lite::Tensor& input,
const lite::Tensor* input,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
@@ -38,9 +38,9 @@ class Pool2dFunctor<lite::TargetType::kX86, PoolProcess, T> {
bool exclusive,
bool adaptive,
lite::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int batch_size = input->dims()[0];
const int input_height = input->dims()[2];
const int input_width = input->dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
@@ -54,7 +54,7 @@ class Pool2dFunctor<lite::TargetType::kX86, PoolProcess, T> {
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;
const T* input_data = input.data<T>();
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(lite::TargetType::kX86);
int hstart, hend;
......
@@ -94,24 +94,12 @@ HOSTDEVICE inline int AdaptEndIndex(int ph, int input_size, int output_size) {
* This is different from average pooling. So we rewrite the max_pool_grad:
* MaxPool2dGradFunctor, MaxPool3dGradFunctor.
*/
//#ifdef PADDLE_WITH_CUDA
// template <typename PoolProcess, typename T>
// class Pool2dDirectCUDAFunctor {
// public:
// void operator()(const T* input, const std::vector<int>& input_shape,
// const std::vector<int>& output_shape,
// const std::vector<int>& ksize,
// const std::vector<int>& strides,
// const std::vector<int>& paddings, PoolProcess pool_compute,
// bool exclusive, T* output, cudaStream_t stream);
//};
//#endif
template <lite::TargetType Target, typename PoolProcess, typename T>
class Pool2dFunctor {
public:
void operator()(const lite::Context<Target>& context,
const lite::Tensor& input,
const lite::Tensor* input,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
......
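
With this change the functor receives its input by pointer rather than by const reference. A caller now looks like the following sketch, mirroring pool_compute.h later in this diff; context, in, and out stand for a live X86Context and NCHW tensors and are assumptions of the example:

    paddle::lite::x86::math::Pool2dFunctor<
        lite::TargetType::kX86, paddle::lite::x86::math::MaxPool<float>, float>
        pool2d_forward;
    paddle::lite::x86::math::MaxPool<float> pool_process;
    // the input tensor is now passed as a pointer; the output already was
    pool2d_forward(context, &in,
                   /*ksize=*/{2, 2}, /*strides=*/{2, 2}, /*paddings=*/{0, 0},
                   pool_process, /*exclusive=*/true, /*adaptive=*/false, &out);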
@@ -4,7 +4,6 @@ add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${li
# lite_cc_library(sgd_compute_x86 SRCS sgd_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(fc_compute_x86 SRCS fc_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(relu_compute_x86 SRCS relu_compute.cc DEPS ${lite_kernel_deps})
add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc DEPS ${lite_kernel_deps})
add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc DEPS ${lite_kernel_deps})
add_kernel(squeeze_compute_x86 X86 basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps})
@@ -15,7 +14,10 @@ add_kernel(conv_compute_x86 X86 basic SRCS conv_compute.cc DEPS ${lite_kernel_de
# lite_cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
# lite_cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} )
# lite_cc_library(conv_compute_x86 SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col)
# lite_cc_library(pool_compute_x86 SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling)
add_kernel(pool_compute_x86 X86 basic SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling)
add_kernel(dropout_compute_x86 X86 basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps})
add_kernel(transpose_compute_x86 X86 basic SRCS transpose_compute.cc DEPS ${lite_kernel_deps} math_function)
# add_kernel(fc_compute_x86 X86 basic SRCS fc_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(batch_norm_compute_x86 SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(uniform_random_compute_x86 SRCS uniform_random_compute.cc DEPS ${lite_kernel_deps} )
add_kernel(gru_compute_x86 X86 basic SRCS gru_compute.cc DEPS ${lite_kernel_deps} blas math_function sequence2batch gru_compute)
@@ -24,7 +26,6 @@ add_kernel(sequence_expand_as_compute_x86 X86 basic SRCS sequence_expand_as_comp
# lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86)
# lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
# lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86)
# lite_cc_test(test_scale_compute_x86 SRCS scale_compute_test.cc DEPS scale_compute_x86)
# lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
# lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86)
@@ -57,3 +58,7 @@ lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS activation_com
lite_cc_test(test_sequence_expand_as_compute_x86 SRCS sequence_expand_as_compute_test.cc DEPS sequence_expand_as_compute_x86)
lite_cc_test(test_gru_compute_x86 SRCS gru_compute_test.cc DEPS gru_compute_x86)
lite_cc_test(test_matmul_compute_x86 SRCS matmul_compute_test.cc DEPS matmul_compute_x86)
lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86)
lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS transpose_compute_x86)
@@ -13,12 +13,14 @@
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <random>
#include <string>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "lite/core/types.h"
#include "lite/fluid/eigen.h"
#include "lite/operators/dropout_op.h"
namespace paddle {
namespace lite {
@@ -28,7 +30,7 @@ namespace x86 {
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using EigenMatrix = lite::fluid::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
class DropoutCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
@@ -37,16 +39,16 @@ class DropoutCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override {
auto& param = *param_.get_mutable<operators::DropoutParam>();
const auto* x_data = param.x->data<T>();
auto* out_data = param.output->template mutable_data<T>();
auto* out_data = param.output->mutable_data<T>();
if (!param.is_test) {
auto* mask_data = param.mask->template mutable_data<T>();
auto* mask_data = param.mask->mutable_data<T>();
std::random_device rnd;
std::minstd_rand engine;
int seed = param.fix_seed ? param.seed : rnd();
engine.seed(seed);
std::uniform_real_distribution<float> dist(0, 1);
size_t size = framework::product(param.mask->dims().data());
size_t size = param.mask->dims().production();
for (size_t i = 0; i < size; ++i) {
if (dist(engine) < param.dropout_prob) {
mask_data[i] = 0;
@@ -62,13 +64,13 @@ class DropoutCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
}
}
} else {
auto X = EigenMatrix<T>::Reshape(param.x->raw_tensor(), 1);
auto Y = EigenMatrix<T>::Reshape(param.output->raw_tensor(), 1);
auto& place = *platform::CPUDeviceContext().eigen_device();
auto X = EigenMatrix<T>::Reshape(*param.x, 1);
auto Y = EigenMatrix<T>::Reshape(*param.output, 1);
if (param.dropout_implementation == "upscale_in_train") {
Y.device(place) = X;
Y.device(lite::fluid::EigenDeviceType<lite::TargetType::kX86>()) = X;
} else {
Y.device(place) = X * static_cast<T>(1.0f - param.dropout_prob);
Y.device(lite::fluid::EigenDeviceType<lite::TargetType::kX86>()) =
X * static_cast<T>(1.0f - param.dropout_prob);
}
}
}
......
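
For reference, the semantics implemented above: in training mode each element is zeroed with probability dropout_prob via the sampled mask; in inference mode the kernel is the identity under "upscale_in_train" and otherwise scales by (1 - dropout_prob). A self-contained sketch of that logic (illustrative only; the training-time branch of "upscale_in_train" is elided in the hunk above and is assumed here to divide kept activations by (1 - prob), the usual inverted-dropout convention):

    #include <random>
    #include <string>
    #include <vector>

    // Illustrative reference of the dropout math, not the Eigen kernel above.
    std::vector<float> DropoutRef(const std::vector<float>& x, float prob,
                                  bool is_test, const std::string& impl,
                                  int seed) {
      std::vector<float> y(x.size());
      if (!is_test) {
        std::minstd_rand engine(seed);
        std::uniform_real_distribution<float> dist(0, 1);
        for (size_t i = 0; i < x.size(); ++i) {
          if (dist(engine) < prob) {
            y[i] = 0.0f;  // dropped: mask[i] = 0
          } else {
            // assumed upscale convention for "upscale_in_train"
            y[i] = (impl == "upscale_in_train") ? x[i] / (1.0f - prob) : x[i];
          }
        }
      } else {
        for (size_t i = 0; i < x.size(); ++i) {
          // inference: identity under "upscale_in_train", otherwise downscale
          y[i] = (impl == "upscale_in_train") ? x[i] : x[i] * (1.0f - prob);
        }
      }
      return y;
    }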
@@ -15,6 +15,8 @@
#include "lite/kernels/x86/dropout_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
@@ -60,7 +62,9 @@ TEST(dropout_x86, run_test) {
param.is_test = true;
param.fix_seed = true;
param.output = &out;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
dropout.SetContext(std::move(ctx));
dropout.SetParam(param);
dropout.Run();
......
@@ -14,12 +14,12 @@
#pragma once
#include <Eigen/Core>
#include "lite/backends/x86/math/math_function.h"
#include "lite/backends/x86/math/pooling.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
#include "lite/fluid/eigen.h"
namespace paddle {
namespace lite {
@@ -31,6 +31,7 @@ class PoolCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::PoolParam;
void Run() override {
auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<param_t>();
if (param.global_pooling) {
for (size_t i = 0; i < param.ksize.size(); ++i) {
@@ -41,37 +42,37 @@ class PoolCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
switch (param.ksize.size()) {
case 2: {
if (param.pooling_type == "max") {
paddle::operators::math::Pool2dFunctor<
platform::CPUDeviceContext,
paddle::operators::math::MaxPool<T>,
paddle::lite::x86::math::Pool2dFunctor<
lite::TargetType::kX86,
paddle::lite::x86::math::MaxPool<T>,
T>
pool2d_forward;
paddle::operators::math::MaxPool<T> pool_process;
pool2d_forward(platform::CPUDeviceContext(),
param.x->raw_tensor(),
paddle::lite::x86::math::MaxPool<T> pool_process;
pool2d_forward(context,
param.x,
param.ksize,
param.strides,
param.paddings,
pool_process,
true,
false,
&(param.output->raw_tensor()));
param.output);
} else if (param.pooling_type == "avg") {
paddle::operators::math::Pool2dFunctor<
platform::CPUDeviceContext,
paddle::operators::math::AvgPool<T>,
paddle::lite::x86::math::Pool2dFunctor<
lite::TargetType::kX86,
paddle::lite::x86::math::AvgPool<T>,
T>
pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(platform::CPUDeviceContext(),
param.x->raw_tensor(),
paddle::lite::x86::math::AvgPool<T> pool_process;
pool2d_forward(context,
param.x,
param.ksize,
param.strides,
param.paddings,
pool_process,
param.exclusive,
param.adaptive,
&(param.output->raw_tensor()));
param.output);
}
} break;
case 3: {
......
@@ -15,6 +15,8 @@
#include "lite/kernels/x86/pool_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
@@ -61,7 +63,9 @@ TEST(pool2d_x86, run_test) {
param.paddings = {0, 0};
param.ksize = {2, 2};
param.pooling_type = "max";
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
pool2d.SetContext(std::move(ctx));
pool2d.SetParam(param);
pool2d.Run();
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/transpose_compute.h"
REGISTER_LITE_KERNEL(transpose,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::TransposeCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
REGISTER_LITE_KERNEL(transpose2,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::Transpose2Compute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
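
Kernels registered this way are pulled into a binary with USE_LITE_KERNEL and looked up from the registry by name; the unit test below follows exactly this pattern:

    // Registration/lookup pairing (taken from the unit test below).
    USE_LITE_KERNEL(transpose, kX86, kFloat, kNCHW, def);
    auto kernels = KernelRegistry::Global()
                       .Create<TARGET(kX86), PRECISION(kFloat)>("transpose");
    // kernels is non-empty only if the registration above was linked in.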
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <vector>
#include "lite/backends/x86/math/math_function.h"
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/transpose_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <lite::TargetType Target, typename T>
inline void TransCompute(const int dim,
const lite::Context<Target>& context,
const lite::Tensor& in,
lite::Tensor* out,
const std::vector<int>& axis) {
switch (dim) {
case 1:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 1> trans1;
trans1(context, in, out, axis);
break;
case 2:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 2> trans2;
trans2(context, in, out, axis);
break;
case 3:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 3> trans3;
trans3(context, in, out, axis);
break;
case 4:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 4> trans4;
trans4(context, in, out, axis);
break;
case 5:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 5> trans5;
trans5(context, in, out, axis);
break;
case 6:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 6> trans6;
trans6(context, in, out, axis);
break;
default:
PADDLE_THROW("Tensors with rank at most 6 are supported");
}
}
template <typename T>
class TransposeCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::TransposeParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto* x = param.x;
auto* out = param.output;
out->mutable_data<T>();
int ndims = param.axis.size();
auto& context = ctx_->As<X86Context>();
TransCompute<lite::TargetType::kX86, T>(
ndims, context, *x, out, param.axis);
}
virtual ~TransposeCompute() = default;
};
template <typename T>
class Transpose2Compute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::TransposeParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto* x = param.x;
auto* out = param.output;
out->mutable_data<T>();
int ndims = param.axis.size();
auto& context = ctx_->As<X86Context>();
TransCompute<lite::TargetType::kX86, T>(
ndims, context, *x, out, param.axis);
}
virtual ~Transpose2Compute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
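
Concretely, for a rank-3 input with axis = {0, 2, 1} (the permutation used by the tests below), Transpose computes out[i][k][j] = in[i][j][k]. A naive index-based reference for row-major data (an illustrative sketch, not the math_function implementation):

    #include <cstdint>
    #include <vector>

    // Naive reference for out = transpose(in, axis) on a rank-3 tensor.
    void NaiveTranspose3D(const float* in, float* out,
                          const std::vector<int64_t>& dims,  // input shape
                          const std::vector<int>& axis) {    // e.g. {0, 2, 1}
      const int64_t D1 = dims[1], D2 = dims[2];
      // output dims are the input dims permuted by axis
      const int64_t odims[3] = {dims[axis[0]], dims[axis[1]], dims[axis[2]]};
      for (int64_t i = 0; i < dims[0]; ++i)
        for (int64_t j = 0; j < D1; ++j)
          for (int64_t k = 0; k < D2; ++k) {
            const int64_t src[3] = {i, j, k};
            // coordinates of this element in the permuted output
            const int64_t o0 = src[axis[0]], o1 = src[axis[1]],
                          o2 = src[axis[2]];
            out[(o0 * odims[1] + o1) * odims[2] + o2] =
                in[(i * D1 + j) * D2 + k];
          }
    }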
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/transpose_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// transpose
TEST(transpose_x86, retrive_op) {
auto transpose =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"transpose");
ASSERT_FALSE(transpose.empty());
ASSERT_TRUE(transpose.front());
}
TEST(transpose_x86, init) {
lite::kernels::x86::TransposeCompute<float> transpose;
ASSERT_EQ(transpose.precision(), PRECISION(kFloat));
ASSERT_EQ(transpose.target(), TARGET(kX86));
}
TEST(transpose_x86, run_test) {
lite::Tensor x;
lite::Tensor out;
std::vector<int64_t> x_shape({3, 4, 5});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({3, 5, 4});
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
// TransposeCompute transpose;
TransposeCompute<float> transpose;
operators::TransposeParam param;
param.x = &x;
param.output = &out;
std::vector<int> axis({0, 2, 1});
param.axis = axis;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
transpose.SetContext(std::move(ctx));
transpose.SetParam(param);
transpose.Run();
for (int j = 0; j < out.dims().production(); ++j) {
// EXPECT_NEAR(out_data[j], x_data[j], 1e-5);
LOG(INFO) << out_data[j];
}
}
// transpose2
TEST(transpose2_x86, retrive_op) {
auto transpose2 =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"transpose2");
ASSERT_FALSE(transpose2.empty());
ASSERT_TRUE(transpose2.front());
}
TEST(transpose2_x86, init) {
lite::kernels::x86::Transpose2Compute<float> transpose2;
ASSERT_EQ(transpose2.precision(), PRECISION(kFloat));
ASSERT_EQ(transpose2.target(), TARGET(kX86));
}
TEST(transpose2_x86, run_test) {
lite::Tensor x;
lite::Tensor xshape;
lite::Tensor out;
std::vector<int64_t> x_shape({3, 4, 5});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({3, 5, 4});
out.Resize(lite::DDim(out_shape));
std::vector<int64_t> xshape_shape({3, 4, 5});
xshape.Resize(lite::DDim(xshape_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
auto xshape_data = xshape.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
xshape_data[i] = static_cast<float>(i);
}
// Transpose2Compute transpose2;
Transpose2Compute<float> transpose2;
operators::TransposeParam param;
param.x = &x;
param.output = &out;
param.xshape = &xshape;
std::vector<int> axis({0, 2, 1});
param.axis = axis;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
transpose2.SetContext(std::move(ctx));
transpose2.SetParam(param);
transpose2.Run();
for (int j = 0; j < out.dims().production(); ++j) {
LOG(INFO) << out_data[j];
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(transpose, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(transpose2, kX86, kFloat, kNCHW, def);
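
Both run_test bodies above only log the outputs; the EXPECT_NEAR check is commented out. For the {3, 4, 5} input with axis = {0, 2, 1}, an explicit check could read as follows (a sketch, not part of this commit):

    // out[i][k][j] == x[i][j][k] for x of shape {3, 4, 5}, axis = {0, 2, 1}
    for (int i = 0; i < 3; ++i)
      for (int j = 0; j < 4; ++j)
        for (int k = 0; k < 5; ++k)
          EXPECT_NEAR(out_data[(i * 5 + k) * 4 + j],
                      x_data[(i * 4 + j) * 5 + k], 1e-5);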
@@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/dropout_op.h"
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
@@ -20,16 +21,12 @@ namespace paddle {
namespace lite {
namespace operators {
class DropoutOpLite : public OpLite {
public:
explicit DropoutOpLite(const std::string& type) : OpLite(type) {}
bool CheckShape() const override {
bool DropoutOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
return true;
}
}
bool InferShape() const override {
bool DropoutOp::InferShape() const {
const auto x_dims = param_.x->dims();
param_.output->Resize(x_dims);
if (param_.is_test == false) {
@@ -38,11 +35,10 @@ class DropoutOpLite : public OpLite {
// share LoD
// param_.output->set_lod(param_.input->lod());
return true;
}
}
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool DropoutOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
auto input = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front();
auto Mask = op_desc.Output("Mask").front();
@@ -63,16 +59,10 @@ class DropoutOpLite : public OpLite {
param_.dropout_implementation =
op_desc.GetAttr<std::string>("dropout_implementation");
return true;
}
std::string DebugString() const override { return "dropout"; }
private:
mutable DropoutParam param_;
};
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(dropout, paddle::lite::operators::DropoutOpLite);
REGISTER_LITE_OP(dropout, paddle::lite::operators::DropoutOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
class DropoutOp : public OpLite {
public:
explicit DropoutOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override;
std::string DebugString() const override { return "dropout"; }
private:
mutable DropoutParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
@@ -87,6 +87,10 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.in_num_col_dims = op_desc.GetAttr<int>("in_num_col_dims");
if (op_desc.HasAttr("activation_type")) {
param_.activation_type = op_desc.GetAttr<std::string>("activation_type");
}
// For Int8
if (op_desc.HasAttr("enable_int8")) {
param_.enable_int8 = op_desc.GetAttr<bool>("enable_int8");
......
@@ -83,6 +83,7 @@ struct FcParam {
lite::Tensor* output{nullptr};
lite::DDim in_mat_dims;
int in_num_col_dims{1};
std::string activation_type{""};
// for int8
WITH_INT8_CONFIG
};
@@ -323,6 +324,8 @@ struct SplitParam {
struct TransposeParam {
const lite::Tensor* x{};
lite::Tensor* output{};
lite::Tensor* xshape{};
std::vector<int> axis;
bool use_mkldnn{false};
std::string data_format{"AnyLayout"};
......
@@ -154,6 +154,10 @@ bool Transpose2Op::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
if (op_desc.HasAttr("data_format")) {
param_.data_format = op_desc.GetAttr<std::string>("data_format");
}
if (op_desc.HasOutput("XShape")) {
auto xshape_var = scope->FindVar(op_desc.Output("XShape").front());
param_.xshape = xshape_var->GetMutable<lite::Tensor>();
}
return true;
}
......