提交 97882cfa 编写于 作者: L liu zhengxi 提交者: GitHub

delete useless code for x86 platform (#2535)

上级 6b393f96
...@@ -12,20 +12,6 @@ ...@@ -12,20 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector> #include <vector>
......
...@@ -54,29 +54,6 @@ class MeanCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -54,29 +54,6 @@ class MeanCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
virtual ~MeanCompute() = default; virtual ~MeanCompute() = default;
}; };
/// x86 kernel for the gradient of the `mean` operator.
///
/// Run() writes `Out_grad / numel(X_grad)` into every element of `X_grad`.
/// `Out_grad` is required (via CHECK_EQ) to hold exactly one element.
///
/// @tparam T element type used for the gradient buffers.
template <typename T>
class MeanGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::MeanGradParam;

  /// Broadcasts the scaled scalar output gradient over the input gradient.
  void Run() override {
    auto& param = *param_.get_mutable<param_t>();
    auto& context = ctx_->As<X86Context>();
    CHECK_EQ(param.Out_grad->raw_tensor().numel(), 1);
    CHECK(context.x86_device_context());

    param.X_grad->template mutable_data<T>();

    // Keep the element count in its integral form for the broadcast extent.
    // Deriving the extent from the T-typed value (as int(T(numel))) is lossy
    // once numel exceeds what T can represent exactly (2^24 for float), which
    // would make the broadcast size disagree with the destination size.
    const auto x_grad_numel = param.X_grad->raw_tensor().numel();
    const T x_grad_size = static_cast<T>(x_grad_numel);
    Eigen::DSizes<int, 1> bcast(static_cast<int>(x_grad_numel));

    EigenVector<T>::Flatten(param.X_grad->raw_tensor())
        .device(*(context.x86_device_context()->eigen_device())) =
        (EigenVector<T>::From(param.Out_grad->raw_tensor()) / x_grad_size)
            .broadcast(bcast);
  }

  virtual ~MeanGradCompute() = default;
};
} // namespace x86 } // namespace x86
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
...@@ -93,16 +70,3 @@ REGISTER_LITE_KERNEL(mean, ...@@ -93,16 +70,3 @@ REGISTER_LITE_KERNEL(mean,
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
// Registers MeanGradCompute<float> as the `mean_grad` kernel for the kX86
// target with kFloat precision and kNCHW layout, under the alias `def`.
// Inputs: "X" and the gradient-variable name derived from "Out" via
// paddle::framework::GradVarName. Output: the gradient-variable name derived
// from "X". All bound tensors are declared as kX86 tensors.
REGISTER_LITE_KERNEL(mean_grad,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::MeanGradCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput(paddle::framework::GradVarName("Out"),
{LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput(paddle::framework::GradVarName("X"),
{LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
...@@ -24,21 +24,3 @@ REGISTER_LITE_KERNEL(mul, ...@@ -24,21 +24,3 @@ REGISTER_LITE_KERNEL(mul,
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
// #ifdef LITE_WITH_TRAIN
// REGISTER_LITE_KERNEL(mul_grad,
// kX86,
// kFloat,
// kNCHW,
// paddle::lite::kernels::x86::MulGradCompute<float>,
// def)
// .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
// .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
// .BindInput(paddle::framework::GradVarName("Out"),
// {LiteType::GetTensorTy(TARGET(kX86))})
// .BindOutput(paddle::framework::GradVarName("X"),
// {LiteType::GetTensorTy(TARGET(kX86))})
// .BindOutput(paddle::framework::GradVarName("Y"),
// {LiteType::GetTensorTy(TARGET(kX86))})
// .Finalize();
// #endif
...@@ -81,78 +81,6 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -81,78 +81,6 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
virtual ~MulCompute() = default; virtual ~MulCompute() = default;
}; };
#ifdef LITE_WITH_TRAIN
/// x86 kernel for the gradient of the `mul` (matrix multiply) operator.
///
/// Inputs with more than two dimensions are viewed as 2-D matrices using
/// `x_num_col_dims` / `y_num_col_dims`. With x: M x K, y: K x N and
/// dout: M x N, Run() computes
///   dx = dout * y'   (M x K)
///   dy = x' * dout   (K x N)
/// for whichever of the two gradient outputs is present.
///
/// @tparam T element type used for the BLAS calls and gradient buffers.
template <typename T>
class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  /// Computes the requested input gradients from the output gradient.
  void Run() override {
    auto& context = ctx_->As<X86Context>();
    auto& param = *param_.get_mutable<operators::MulGradParam>();
    CHECK(context.x86_device_context());

    auto* x = &param.x->raw_tensor();
    auto* y = &param.y->raw_tensor();

    // View the forward inputs as 2-D matrices when they have higher rank.
    Tensor x_matrix, y_matrix;
    if (x->dims().size() > 2) {
      x_matrix = framework::ReshapeToMatrix(*x, param.x_num_col_dims);
    } else {
      x_matrix = *x;
    }
    if (y->dims().size() > 2) {
      y_matrix = framework::ReshapeToMatrix(*y, param.y_num_col_dims);
    } else {
      y_matrix = *y;
    }

    // Share dout's storage and reshape the view to M x N.
    auto* dout = &param.output_grad->raw_tensor();
    Tensor dout_mat;
    dout_mat.ShareDataWith(*dout);
    dout_mat.Resize(
        {framework::flatten_to_2d(x->dims(), param.x_num_col_dims)[0],
         framework::flatten_to_2d(y->dims(), param.y_num_col_dims)[1]});

    // Test the gradient holders themselves before dereferencing them. The
    // address of the reference returned by raw_tensor() can never be null, so
    // checking that address (as was done previously) could not detect a
    // missing gradient output and the holder was dereferenced unconditionally.
    auto* dx =
        (param.x_grad != nullptr) ? &param.x_grad->raw_tensor() : nullptr;
    auto* dy =
        (param.y_grad != nullptr) ? &param.y_grad->raw_tensor() : nullptr;

    if (dx != nullptr) {
      dx->set_lod(x->lod());
    }
    if (dy != nullptr) {
      dy->set_lod(y->lod());
    }

    auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
        *context.x86_device_context());

    if (dx != nullptr) {
      // NOTE(review): mutable_data<T>() presumably materialises the buffer
      // that the BLAS call writes into -- confirm against the tensor API.
      param.x_grad->template mutable_data<T>();
      Tensor dx_matrix = dx->dims().size() > 2
                             ? framework::ReshapeToMatrix(*dx,
                                                          param.x_num_col_dims)
                             : *dx;
      // dx = dout * y'. dx: M x K, dout : M x N, y : K x N
      blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix);
    }

    if (dy != nullptr) {
      // NOTE(review): same assumption about mutable_data<T>() as for dx.
      param.y_grad->template mutable_data<T>();
      Tensor dy_matrix = dy->dims().size() > 2
                             ? framework::ReshapeToMatrix(*dy,
                                                          param.y_num_col_dims)
                             : *dy;
      // dy = x' * dout. dy K x N, dout : M x N, x : M x K
      blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix);
    }
  }

  virtual ~MulGradCompute() = default;
};
#endif
} // namespace x86 } // namespace x86
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/relu_compute.h"
// Registers ReluCompute<float> as the `relu` kernel for the kX86 target with
// kFloat precision and kNCHW layout, under the alias `def`. It binds one
// input ("X") and one output ("Out"), both declared as kX86 tensors.
REGISTER_LITE_KERNEL(relu,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::ReluCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/relu_op.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
/// x86 kernel for the ReLU activation: Out[i] = max(0, X[i]).
///
/// @tparam T element type of the input and output buffers. The buffers are
///           accessed as T, so the template parameter is honoured rather than
///           silently reading the tensors as float.
template <typename T>
class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::ActivationParam;

  /// Applies ReLU element-wise over all `X->dims().production()` elements.
  void Run() override {
    auto& param = *param_.get_mutable<param_t>();
    const auto n = param.X->dims().production();
    const T* input = param.X->template data<T>();
    T* output = param.Out->template mutable_data<T>();

    // Iterating over the pointer range avoids a narrow `int` index being
    // compared against the element count. Zero stays the first argument of
    // std::max so the result for unordered comparisons is unchanged.
    std::transform(input, input + n, output, [](T value) {
      return std::max(static_cast<T>(0), value);
    });
  }

  virtual ~ReluCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册