Commit 6d7e40a9 authored by xiaogang, committed by GitHub

feat: add elementwise_grad op (#3246)

Parent 5045d394
@@ -266,6 +266,72 @@ void elementwise_add_relu_broadcast<float>(const float* dinx,
}
}
template <>
void elementwise_add_grad<float>(const float* dout_grad,
float* x_grad,
int num) {
int cnt = num >> 4;
int remain = num & 0x0f;
#pragma omp parallel for
for (int i = 0; i < cnt; ++i) {
const float* out_data = dout_grad + 16 * i;
float* x_data = x_grad + 16 * i;
float32x4_t din0 = vld1q_f32(out_data);
float32x4_t din1 = vld1q_f32(out_data + 4);
float32x4_t din2 = vld1q_f32(out_data + 8);
float32x4_t din3 = vld1q_f32(out_data + 12);
vst1q_f32(x_data, din0);
vst1q_f32(x_data + 4, din1);
vst1q_f32(x_data + 8, din2);
vst1q_f32(x_data + 12, din3);
}
if (remain > 0) {
const float* out_data = dout_grad + 16 * cnt;
float* x_data = x_grad + 16 * cnt;
for (int i = 0; i < remain; ++i) {
x_data[i] = out_data[i];
}
}
}
// We assume y_data has fewer elements (numel) than x_data; otherwise, call
// this function with the x_grad and y_grad arguments swapped.
template <>
void elementwise_add_grad_broadcast<float>(const float* dout_grad,
float* x_grad,
float* y_grad,
int pre,
int n,
int post) {
if (x_grad) {
elementwise_add_grad(dout_grad, x_grad, pre * n * post);
}
if (y_grad) {
memset(y_grad, 0, n * sizeof(float));
#pragma omp parallel for
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
float sum = 0;
int cnt = post >> 2;
int remain = post & 0x03;
const float* out_data = dout_grad + (i * n + j) * post;
float32x4_t sum_v = vdupq_n_f32(0);
for (int ci = 0; ci < cnt; ++ci) {
float32x4_t din = vld1q_f32(out_data + 4 * ci);
sum_v = vaddq_f32(sum_v, din);
}
out_data += 4 * cnt;
for (int ci = 0; ci < remain; ++ci) {
sum += out_data[ci];
}
float32x2_t high = vget_high_f32(sum_v);
float32x2_t low = vget_low_f32(sum_v);
sum += vget_lane_f32(high, 0) + vget_lane_f32(high, 1) +
vget_lane_f32(low, 0) + vget_lane_f32(low, 1);
y_grad[j] += sum;
}
}
}
}
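// Illustration only: a minimal scalar reference for the broadcast backward of
// z = x + y that the NEON code above implements. The helper name
// elementwise_add_grad_broadcast_ref is hypothetical; with dout_grad viewed as
// a (pre, n, post) tensor, x_grad is a plain copy of dout_grad, while each
// y_grad[j] sums over every pre/post position that y[j] was broadcast to.
static void elementwise_add_grad_broadcast_ref(const float* dout_grad,
                                               float* x_grad,
                                               float* y_grad,
                                               int pre,
                                               int n,
                                               int post) {
  if (x_grad) {
    for (int k = 0; k < pre * n * post; ++k) {
      x_grad[k] = dout_grad[k];  // dz/dx = 1
    }
  }
  if (y_grad) {
    for (int j = 0; j < n; ++j) {
      y_grad[j] = 0.f;
    }
    for (int i = 0; i < pre; ++i) {
      for (int j = 0; j < n; ++j) {
        for (int k = 0; k < post; ++k) {
          // dz/dy = 1, reduced over the broadcast axes
          y_grad[j] += dout_grad[(i * n + j) * post + k];
        }
      }
    }
  }
}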
template <>
void elementwise_sub<float>(const float* dinx,
const float* diny,
@@ -510,6 +576,84 @@ void elementwise_sub_relu_broadcast<float>(const float* dinx,
}
}
}
// We assume the forward formula is z = x - y.
template <>
void elementwise_sub_grad<float>(const float* dout_grad,
float* x_grad,
float* y_grad,
int num) {
if (x_grad) {
elementwise_add_grad(dout_grad, x_grad, num);
}
if (y_grad) {
int cnt = num >> 4;
int remain = num & 0x0f;
float32x4_t minus = vdupq_n_f32(-1);
#pragma omp parallel for
for (int i = 0; i < cnt; ++i) {
const float* out_data = dout_grad + 16 * i;
float* y_data = y_grad + 16 * i;
float32x4_t din0 = vld1q_f32(out_data);
float32x4_t din1 = vld1q_f32(out_data + 4);
float32x4_t din2 = vld1q_f32(out_data + 8);
float32x4_t din3 = vld1q_f32(out_data + 12);
din0 = vmulq_f32(din0, minus);
din1 = vmulq_f32(din1, minus);
din2 = vmulq_f32(din2, minus);
din3 = vmulq_f32(din3, minus);
vst1q_f32(y_data, din0);
vst1q_f32(y_data + 4, din1);
vst1q_f32(y_data + 8, din2);
vst1q_f32(y_data + 12, din3);
}
if (remain > 0) {
const float* out_data = dout_grad + 16 * cnt;
float* y_data = y_grad + 16 * cnt;
for (int i = 0; i < remain; ++i) {
y_data[i] = -out_data[i];
}
}
}
}
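// Illustration only: a scalar equivalent of elementwise_sub_grad above for
// z = x - y (the helper name is hypothetical): x receives dout_grad unchanged
// and y receives its negation.
static inline void elementwise_sub_grad_ref(const float* dout_grad,
                                            float* x_grad,
                                            float* y_grad,
                                            int num) {
  for (int i = 0; i < num; ++i) {
    if (x_grad) x_grad[i] = dout_grad[i];   // dz/dx = +1
    if (y_grad) y_grad[i] = -dout_grad[i];  // dz/dy = -1
  }
}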
// We assume y_data has fewer elements (numel) than x_data; otherwise, call
// this function with the x_grad and y_grad arguments swapped.
template <>
void elementwise_sub_grad_broadcast<float>(const float* dout_grad,
float* x_grad,
float* y_grad,
int pre,
int n,
int post) {
if (x_grad) {
elementwise_add_grad(dout_grad, x_grad, pre * n * post);
}
if (y_grad) {
memset(y_grad, 0, n * sizeof(float));
#pragma omp parallel for
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
float sum = 0;
int cnt = post >> 2;
int remain = post & 0x03;
const float* out_data = dout_grad + (i * n + j) * post;
float32x4_t sum_v = vdupq_n_f32(0);
for (int ci = 0; ci < cnt; ++ci) {
float32x4_t din = vld1q_f32(out_data + 4 * ci);
sum_v = vaddq_f32(sum_v, din);
}
out_data += 4 * cnt;
for (int ci = 0; ci < remain; ++ci) {
sum -= out_data[ci];
}
float32x2_t high = vget_high_f32(sum_v);
float32x2_t low = vget_low_f32(sum_v);
sum -= vget_lane_f32(high, 0) + vget_lane_f32(high, 1) +
vget_lane_f32(low, 0) + vget_lane_f32(low, 1);
y_grad[j] += sum;
}
}
}
}
template <>
void elementwise_mul<float>(const float* dinx,
......
@@ -183,6 +183,13 @@ template <typename T>
void elementwise_add_relu_broadcast(
const T* dinx, const T* diny, T* dout, int batch, int channels, int num);
template <typename T>
void elementwise_add_grad(const T* dout, T* dinx, int num);
template <typename T>
void elementwise_add_grad_broadcast(
const T* dout_grad, T* x_grad, T* y_grad, int pre, int n, int post);
template <typename T>
void elementwise_sub(const T* dinx, const T* diny, T* dout, int num);
@@ -197,6 +204,13 @@ template <typename T>
void elementwise_sub_relu_broadcast(
const T* dinx, const T* diny, T* dout, int batch, int channels, int num);
template <typename T>
void elementwise_sub_grad(const T* dout, T* dinx, T* diny, int num);
template <typename T>
void elementwise_sub_grad_broadcast(
const T* dout_grad, T* x_grad, T* y_grad, int pre, int n, int post);
template <typename T>
void elementwise_mul(const T* dinx, const T* diny, T* dout, int num);
......
@@ -109,6 +109,7 @@ add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_de
if(LITE_WITH_TRAIN)
add_kernel(mean_grad_compute_arm ARM extra SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(activation_grad_compute_arm ARM basic SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(elementwise_grad_compute_arm ARM basic SRCS elementwise_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(mul_grad_compute_arm ARM extra SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sgd_compute_arm ARM extra SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm)
endif()
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/elementwise_grad_compute.h"
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
inline DDim trim_trailing_singular_dims(const DDim& dims) {
// Remove trailing dimensions of size 1 for y
auto actual_dims_size = dims.size();
for (; actual_dims_size != 0; --actual_dims_size) {
if (dims[actual_dims_size - 1] != 1) break;
}
std::vector<int64_t> trim_dims;
trim_dims.resize(actual_dims_size);
for (int i = 0; i < actual_dims_size; ++i) {
trim_dims[i] = dims[i];
}
if (trim_dims.size() == 0) {
return DDim();
}
return DDim(trim_dims);
}
inline bool is_broadcast(const DDim& x_dims,
const DDim& y_dims,
int axis,
int* pre,
int* n,
int* post) {
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
DDim y_dim_trim = trim_trailing_singular_dims(y_dims);
axis = (y_dim_trim.size() == 0) ? x_dims.size() : axis;
if (x_dims.size() == y_dim_trim.size()) {
return false;
}
*pre = 1;
*n = 1;
*post = 1;
for (int i = 0; i < axis; ++i) {
(*pre) *= x_dims[i];
}
for (int i = 0; i < y_dim_trim.size(); ++i) {
CHECK_EQ(x_dims[i + axis], y_dim_trim[i])
<< "Broadcast dimension mismatch.";
(*n) *= y_dim_trim[i];
}
for (int i = axis + y_dim_trim.size(); i < x_dims.size(); ++i) {
(*post) *= x_dims[i];
}
return true;
}
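// Worked example (illustration only): for x_dims = {2, 3, 4, 5},
// y_dims = {3, 4, 1} and axis = 1, the trailing 1 of y is trimmed away, so
//   pre  = 2           (x dims before axis)
//   n    = 3 * 4 = 12  (dims x shares with the trimmed y)
//   post = 5           (x dims after the shared block)
// The grad kernels then treat dout_grad as a (2, 12, 5) tensor and reduce it
// into a y_grad of 12 elements.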
void ElementwiseAddGradCompute::Run() {
auto& param = Param<operators::ElementwiseGradParam>();
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
const float* out_grad_data = param.OutGrad->data<float>();
float* x_grad_data = param.XGrad->mutable_data<float>();
float* y_grad_data = param.YGrad->mutable_data<float>();
int axis = param.axis;
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
int pre, n, post;
if (x_dims.size() < y_dims.size() &&
is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_add_grad_broadcast(
out_grad_data, y_grad_data, x_grad_data, pre, n, post);
} else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_add_grad_broadcast(
out_grad_data, x_grad_data, y_grad_data, pre, n, post);
} else {
lite::arm::math::elementwise_add_grad(
out_grad_data, x_grad_data, x_dims.production());
lite::arm::math::elementwise_add_grad(
out_grad_data, y_grad_data, y_dims.production());
}
}
void ElementwiseSubGradCompute::Run() {
auto& param = Param<operators::ElementwiseGradParam>();
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
const float* out_data = param.OutGrad->data<float>();
float* x_grad_data = param.XGrad->mutable_data<float>();
float* y_grad_data = param.YGrad->mutable_data<float>();
int axis = param.axis;
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
int pre, n, post;
if (x_dims.size() < y_dims.size()) {
LOG(FATAL) << "elewise div grad don't support x_dims size < y_dims size";
}
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_sub_grad_broadcast(
out_data, x_grad_data, y_grad_data, pre, n, post);
} else {
lite::arm::math::elementwise_sub_grad(
out_data, x_grad_data, y_grad_data, x_dims.production());
}
}
template <typename T, PrecisionType PType>
void ElementwiseMulGradCompute<T, PType>::Run() {
LOG(FATAL) << "elementwise mul_grad not implement yet";
}
void ElementwiseMaxGradCompute::Run() {
LOG(FATAL) << "elementwise max_grad not implement yet";
}
void ElementwiseDivGradCompute::Run() {
LOG(FATAL) << "elementwise div_grad not implement yet";
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
using elementwise_mul_grad_float =
paddle::lite::kernels::arm::ElementwiseMulGradCompute<float,
PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(elementwise_add_grad,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ElementwiseAddGradCompute,
def)
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_sub_grad,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ElementwiseSubGradCompute,
def)
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_div_grad,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ElementwiseDivGradCompute,
def)
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(
elementwise_mul_grad, kARM, kFloat, kNCHW, elementwise_mul_grad_float, def)
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_max_grad,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ElementwiseMaxGradCompute,
def)
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class ElementwiseAddGradCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ElementwiseAddGradCompute() = default;
};
class ElementwiseSubGradCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ElementwiseSubGradCompute() = default;
};
template <typename T, PrecisionType PType>
class ElementwiseMulGradCompute : public KernelLite<TARGET(kARM), PType> {
public:
void Run() override;
virtual ~ElementwiseMulGradCompute() = default;
};
class ElementwiseMaxGradCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ElementwiseMaxGradCompute() = default;
};
class ElementwiseDivGradCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ElementwiseDivGradCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
@@ -144,6 +144,7 @@ add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS})
if (LITE_WITH_TRAIN)
add_operator(mean_grad_op extra SRCS mean_grad_op.cc DEPS ${op_DEPS})
add_operator(activation_grad_ops basic SRCS activation_grad_ops.cc DEPS ${op_DEPS})
add_operator(elementwise_grad_op extra SRCS elementwise_grad_ops.cc DEPS ${op_DEPS})
add_operator(mul_grad_op basic SRCS mul_grad_op.cc DEPS ${op_DEPS})
add_operator(sgd_op extra SRCS sgd_op.cc DEPS ${op_DEPS})
endif()
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/elementwise_grad_ops.h"
#include <algorithm>
#include <cmath>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool ElementwiseGradOp::CheckShape() const {
CHECK_OR_FALSE(param_.XGrad);
CHECK_OR_FALSE(param_.YGrad);
CHECK_OR_FALSE(param_.OutGrad);
return true;
}
bool ElementwiseGradOp::InferShape() const {
auto x_dim = param_.X->dims();
auto y_dim = param_.Y->dims();
param_.XGrad->Resize(x_dim);
param_.YGrad->Resize(y_dim);
return true;
}
bool ElementwiseGradOp::AttachImpl(const cpp::OpDesc& opdesc,
lite::Scope* scope) {
auto Y_name = opdesc.Input("Y").front();
auto X_name = opdesc.Input("X").front();
auto Out_name = opdesc.Input("Out@Grad").front();
auto x_grad_name = opdesc.Output("X@Grad").front();
auto y_grad_name = opdesc.Output("Y@Grad").front();
param_.X = GetVar<lite::Tensor>(scope, X_name);
param_.Y = GetVar<lite::Tensor>(scope, Y_name);
param_.XGrad = GetMutableVar<lite::Tensor>(scope, x_grad_name);
param_.YGrad = GetMutableVar<lite::Tensor>(scope, y_grad_name);
param_.OutGrad = GetVar<lite::Tensor>(scope, Out_name);
param_.axis = opdesc.GetAttr<int>("axis");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(elementwise_sub_grad,
paddle::lite::operators::ElementwiseGradOp);
REGISTER_LITE_OP(elementwise_add_grad,
paddle::lite::operators::ElementwiseGradOp);
REGISTER_LITE_OP(elementwise_mul_grad,
paddle::lite::operators::ElementwiseGradOp);
REGISTER_LITE_OP(elementwise_max_grad,
paddle::lite::operators::ElementwiseGradOp);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class ElementwiseGradOp : public OpLite {
public:
explicit ElementwiseGradOp(const std::string& op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "elementwise_grad_op"; }
private:
mutable operators::ElementwiseGradParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
@@ -387,10 +387,11 @@ struct ElementwiseParam {
};
struct ElementwiseGradParam {
const lite::Tensor* X{};
const lite::Tensor* Y{};
const lite::Tensor* Out_grad{};
lite::Tensor* X_grad{};
lite::Tensor* Y_grad{};
const lite::Tensor* OutGrad{};
lite::Tensor* XGrad{};
lite::Tensor* YGrad{};
int axis{-1}; // for broadcasting.
};
......
@@ -65,6 +65,7 @@ if(LITE_BUILD_EXTRA)
if (LITE_WITH_TRAIN)
lite_cc_test(test_kernel_mean_compute SRCS mean_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_activation_grad_compute SRCS activation_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_elementwise_grad_compute SRCS elementwise_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_mul_grad_compute SRCS mul_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sgd_compute SRCS sgd_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/elementwise_grad_compute.h"
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/elementwise_compute.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
using param_t = operators::ElementwiseParam;
using grad_param_t = operators::ElementwiseGradParam;
using kernel_add_t = ElementwiseAddCompute;
using grad_kernel_add_t = ElementwiseAddGradCompute;
using kernel_sub_t = ElementwiseSubCompute;
using grad_kernel_sub_t = ElementwiseSubGradCompute;
void elementwise_common(grad_param_t& param, // NOLINT
std::vector<float>& out_grad, // NOLINT
std::vector<float>& x_grad, // NOLINT
std::vector<float>& y_grad, // NOLINT
std::string flag) {
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
if (x_dims == y_dims) {
for (int i = 0; i < x_dims.production(); ++i) {
if (flag == "add") {
x_grad[i] = out_grad[i];
y_grad[i] = out_grad[i];
}
if (flag == "sub") {
x_grad[i] = out_grad[i];
y_grad[i] = -out_grad[i];
}
}
} else {
LOG(FATAL) << "unsupport dims";
}
}
class ElementwiseAddGradTester {
public:
explicit ElementwiseAddGradTester(const DDim& x_dims,
const DDim& y_dims,
int axis)
: x_dims_(x_dims), y_dims_(y_dims), axis_(axis) {}
void prepare_kernel() {
std::unique_ptr<KernelContext> ctx1(new KernelContext);
ctx1->As<ARMContext>();
kernel_.SetContext(std::move(ctx1));
std::unique_ptr<KernelContext> ctx3(new KernelContext);
ctx3->As<ARMContext>();
grad_kernel_.SetContext(std::move(ctx3));
}
void run_forward(param_t* param,
kernel_add_t* kernel,
const std::vector<float>& x_vec,
const std::vector<float>& y_vec,
float* out_vec) {
Tensor x;
Tensor y;
Tensor output;
x.Resize(x_dims_);
y.Resize(y_dims_);
output.Resize(DDim(out_dims_));
auto* x_data = x.mutable_data<float>();
auto* y_data = y.mutable_data<float>();
for (int i = 0; i < x_dims_.production(); i++) {
x_data[i] = x_vec[i];
}
for (int i = 0; i < y_dims_.production(); i++) {
y_data[i] = y_vec[i];
}
param->X = &x;
param->Y = &y;
param->Out = &output;
param->axis = axis_;
kernel->SetParam(*param);
kernel->Launch();
auto* output_data = output.mutable_data<float>();
for (int i = 0; i < out_dims_.production(); i++) {
out_vec[i] = output_data[i];
}
}
void run_backward(grad_param_t* param,
grad_kernel_add_t* kernel,
const std::vector<float>& x_vec,
const std::vector<float>& y_vec,
const std::vector<float>& out_grad_vec,
float* x_grad_vec,
float* y_grad_vec) {
Tensor x;
Tensor x_grad;
Tensor y;
Tensor y_grad;
Tensor out_grad;
x.Resize(x_dims_);
x_grad.Resize(x_dims_);
y.Resize(y_dims_);
y_grad.Resize(y_dims_);
out_grad.Resize(out_dims_);
auto* x_data = x.mutable_data<float>();
auto* y_data = y.mutable_data<float>();
auto* out_grad_data = out_grad.mutable_data<float>();
for (int i = 0; i < x_dims_.production(); i++) {
x_data[i] = x_vec[i];
}
for (int i = 0; i < y_dims_.production(); i++) {
y_data[i] = y_vec[i];
}
for (int i = 0; i < out_dims_.production(); i++) {
out_grad_data[i] = out_grad_vec[i];
}
param->X = &x;
param->XGrad = &x_grad;
param->Y = &y;
param->YGrad = &y_grad;
param->OutGrad = &out_grad;
param->axis = axis_;
kernel->SetParam(*param);
kernel->Launch();
auto* x_grad_data = x_grad.mutable_data<float>();
auto* y_grad_data = y_grad.mutable_data<float>();
for (int i = 0; i < x_dims_.production(); i++) {
x_grad_vec[i] = x_grad_data[i];
}
for (int i = 0; i < y_dims_.production(); i++) {
y_grad_vec[i] = y_grad_data[i];
}
}
void check_grad(float delta2, float max_grad_delta2) {
std::vector<int64_t> out_shape;
// infer shape
auto x_dim = x_dims_;
auto y_dim = y_dims_;
if (x_dim == y_dim) {
out_dims_ = x_dim;
} else {
int max_dim = (x_dim.size() > y_dim.size() ? x_dim.size() : y_dim.size());
int axis = param_.axis;
axis =
(axis == -1 ? std::abs(static_cast<int>(x_dim.size() - y_dim.size()))
: axis);
std::vector<int64_t> x_dims_array(max_dim);
std::vector<int64_t> y_dims_array(max_dim);
std::vector<int64_t> out_dims_array(max_dim);
if (x_dim.size() > y_dim.size()) {
for (int i = 0; i < axis; ++i) {
y_dims_array[i] = 1;
}
if (axis + y_dim.size() < max_dim) {
for (int i = axis + y_dim.size(); i < max_dim; ++i) {
y_dims_array[i] = 1;
}
}
x_dims_array = x_dim.Vectorize();
for (int i = 0; i < y_dim.size(); ++i) {
y_dims_array[i + axis] = y_dim[i];
}
} else {
for (int i = 0; i < axis; ++i) {
x_dims_array[i] = 1;
}
if (axis + x_dim.size() < max_dim) {
for (int i = axis + x_dim.size(); i < max_dim; ++i) {
x_dims_array[i] = 1;
}
}
y_dims_array = y_dim.Vectorize();
for (int i = 0; i < x_dim.size(); ++i) {
x_dims_array[i + axis] = x_dim[i];
}
}
for (int i = 0; i < max_dim; i++) {
if (x_dims_array[i] == -1 || y_dims_array[i] == -1) {
out_dims_array[i] = -1;
} else {
out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]);
}
}
out_dims_ = DDim(out_dims_array);
}
// infer end
// forward
std::vector<float> x(x_dims_.production());
std::vector<float> y(y_dims_.production());
std::vector<float> out(out_dims_.production());
fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production());
fill_data_rand(y.data(), -1.f, 1.f, y_dims_.production());
this->run_forward(&param_, &kernel_, x, y, out.data());
for (int i = 0; i < x_dims_.production(); i++) {
LOG(INFO) << "x_" << i << ": " << x[i];
}
for (int i = 0; i < y_dims_.production(); i++) {
LOG(INFO) << "y_" << i << ": " << y[i];
}
for (int i = 0; i < out_dims_.production(); i++) {
LOG(INFO) << "out_" << i << ": " << out[i];
}
// backward
std::vector<float> out_grad(out_dims_.production());
std::vector<float> x_grad(x_dims_.production());
std::vector<float> y_grad(y_dims_.production());
for (int i = 0; i < out_dims_.production(); i++) {
out_grad[i] = 1.0;
}
this->run_backward(&grad_param_,
&grad_kernel_,
x,
y,
out_grad,
x_grad.data(),
y_grad.data());
for (int i = 0; i < x_grad.size(); i++) {
LOG(INFO) << "x_grad_" << i << ": " << x_grad[i];
}
for (int i = 0; i < y_grad.size(); i++) {
LOG(INFO) << "y_grad_" << i << ": " << y_grad[i];
}
// get numeric gradient
std::vector<float> x_delta(x_dims_.production());
std::vector<float> y_delta(y_dims_.production());
std::vector<float> out_delta(out_dims_.production());
Tensor tensor_x;
Tensor tensor_y;
tensor_x.Resize(x_dims_);
tensor_y.Resize(y_dims_);
grad_param_.X = &tensor_x;
grad_param_.Y = &tensor_y;
elementwise_common(grad_param_, out_grad, x_delta, y_delta, "add");
float max_grad_delta = 0.0005;
for (int i = 0; i < x_dims_.production(); i++) {
EXPECT_NEAR(x_grad[i], x_delta[i], max_grad_delta);
EXPECT_NEAR(y_grad[i], y_delta[i], max_grad_delta);
}
}
private:
DDim x_dims_;
DDim y_dims_;
DDim out_dims_;
int axis_;
kernel_add_t kernel_;
grad_kernel_add_t grad_kernel_;
param_t param_;
grad_param_t grad_param_;
};
class ElementwiseSubGradTester {
public:
explicit ElementwiseSubGradTester(const DDim& x_dims,
const DDim& y_dims,
int axis)
: x_dims_(x_dims), y_dims_(y_dims), axis_(axis) {}
void prepare_kernel() {
std::unique_ptr<KernelContext> ctx1(new KernelContext);
ctx1->As<ARMContext>();
kernel_.SetContext(std::move(ctx1));
std::unique_ptr<KernelContext> ctx3(new KernelContext);
ctx3->As<ARMContext>();
grad_kernel_.SetContext(std::move(ctx3));
}
void run_forward(param_t* param,
kernel_sub_t* kernel,
const std::vector<float>& x_vec,
const std::vector<float>& y_vec,
float* out_vec) {
Tensor x;
Tensor y;
Tensor output;
x.Resize(x_dims_);
y.Resize(y_dims_);
output.Resize(DDim(out_dims_));
auto* x_data = x.mutable_data<float>();
auto* y_data = y.mutable_data<float>();
for (int i = 0; i < x_dims_.production(); i++) {
x_data[i] = x_vec[i];
}
for (int i = 0; i < y_dims_.production(); i++) {
y_data[i] = y_vec[i];
}
param->X = &x;
param->Y = &y;
param->Out = &output;
param->axis = axis_;
kernel->SetParam(*param);
kernel->Launch();
auto* output_data = output.mutable_data<float>();
for (int i = 0; i < out_dims_.production(); i++) {
out_vec[i] = output_data[i];
}
}
void run_backward(grad_param_t* param,
grad_kernel_sub_t* kernel,
const std::vector<float>& x_vec,
const std::vector<float>& y_vec,
const std::vector<float>& out_grad_vec,
float* x_grad_vec,
float* y_grad_vec) {
Tensor x;
Tensor x_grad;
Tensor y;
Tensor y_grad;
Tensor out_grad;
x.Resize(x_dims_);
x_grad.Resize(x_dims_);
y.Resize(y_dims_);
y_grad.Resize(y_dims_);
out_grad.Resize(out_dims_);
auto* x_data = x.mutable_data<float>();
auto* y_data = y.mutable_data<float>();
auto* out_grad_data = out_grad.mutable_data<float>();
for (int i = 0; i < x_dims_.production(); i++) {
x_data[i] = x_vec[i];
}
for (int i = 0; i < y_dims_.production(); i++) {
y_data[i] = y_vec[i];
}
for (int i = 0; i < out_dims_.production(); i++) {
out_grad_data[i] = out_grad_vec[i];
}
param->X = &x;
param->XGrad = &x_grad;
param->Y = &y;
param->YGrad = &y_grad;
param->OutGrad = &out_grad;
param->axis = axis_;
kernel->SetParam(*param);
kernel->Launch();
auto* x_grad_data = x_grad.mutable_data<float>();
auto* y_grad_data = y_grad.mutable_data<float>();
for (int i = 0; i < x_dims_.production(); i++) {
x_grad_vec[i] = x_grad_data[i];
}
for (int i = 0; i < y_dims_.production(); i++) {
y_grad_vec[i] = y_grad_data[i];
}
}
void check_grad(float delta2, float max_grad_delta2) {
std::vector<int64_t> out_shape;
// infer shape
auto x_dim = x_dims_;
auto y_dim = y_dims_;
if (x_dim == y_dim) {
out_dims_ = x_dim;
} else {
int max_dim = (x_dim.size() > y_dim.size() ? x_dim.size() : y_dim.size());
int axis = param_.axis;
axis =
(axis == -1 ? std::abs(static_cast<int>(x_dim.size() - y_dim.size()))
: axis);
std::vector<int64_t> x_dims_array(max_dim);
std::vector<int64_t> y_dims_array(max_dim);
std::vector<int64_t> out_dims_array(max_dim);
if (x_dim.size() > y_dim.size()) {
for (int i = 0; i < axis; ++i) {
y_dims_array[i] = 1;
}
if (axis + y_dim.size() < max_dim) {
for (int i = axis + y_dim.size(); i < max_dim; ++i) {
y_dims_array[i] = 1;
}
}
x_dims_array = x_dim.Vectorize();
for (int i = 0; i < y_dim.size(); ++i) {
y_dims_array[i + axis] = y_dim[i];
}
} else {
for (int i = 0; i < axis; ++i) {
x_dims_array[i] = 1;
}
if (axis + x_dim.size() < max_dim) {
for (int i = axis + x_dim.size(); i < max_dim; ++i) {
x_dims_array[i] = 1;
}
}
y_dims_array = y_dim.Vectorize();
for (int i = 0; i < x_dim.size(); ++i) {
x_dims_array[i + axis] = x_dim[i];
}
}
for (int i = 0; i < max_dim; i++) {
if (x_dims_array[i] == -1 || y_dims_array[i] == -1) {
out_dims_array[i] = -1;
} else {
out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]);
}
}
out_dims_ = DDim(out_dims_array);
}
// infer end
// forward
std::vector<float> x(x_dims_.production());
std::vector<float> y(y_dims_.production());
std::vector<float> out(out_dims_.production());
fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production());
fill_data_rand(y.data(), -1.f, 1.f, y_dims_.production());
this->run_forward(&param_, &kernel_, x, y, out.data());
for (int i = 0; i < x_dims_.production(); i++) {
LOG(INFO) << "x_" << i << ": " << x[i];
}
for (int i = 0; i < y_dims_.production(); i++) {
LOG(INFO) << "y_" << i << ": " << y[i];
}
for (int i = 0; i < out_dims_.production(); i++) {
LOG(INFO) << "out_" << i << ": " << out[i];
}
// backward
std::vector<float> out_grad(out_dims_.production());
std::vector<float> x_grad(x_dims_.production());
std::vector<float> y_grad(y_dims_.production());
for (int i = 0; i < out_dims_.production(); i++) {
out_grad[i] = 1.0;
}
this->run_backward(&grad_param_,
&grad_kernel_,
x,
y,
out_grad,
x_grad.data(),
y_grad.data());
for (int i = 0; i < x_grad.size(); i++) {
LOG(INFO) << "x_grad_" << i << ": " << x_grad[i];
}
for (int i = 0; i < y_grad.size(); i++) {
LOG(INFO) << "y_grad_" << i << ": " << y_grad[i];
}
// get numeric gradient
std::vector<float> x_delta(x_dims_.production());
std::vector<float> y_delta(y_dims_.production());
std::vector<float> out_delta(out_dims_.production());
Tensor tensor_x;
Tensor tensor_y;
tensor_x.Resize(x_dims_);
tensor_y.Resize(y_dims_);
grad_param_.X = &tensor_x;
grad_param_.Y = &tensor_y;
elementwise_common(grad_param_, out_grad, x_delta, y_delta, "sub");
float max_grad_delta = 0.0005;
for (int i = 0; i < x_dims_.production(); i++) {
EXPECT_NEAR(x_grad[i], x_delta[i], max_grad_delta);
EXPECT_NEAR(y_grad[i], y_delta[i], max_grad_delta);
}
}
private:
DDim x_dims_;
DDim y_dims_;
DDim out_dims_;
int axis_;
kernel_sub_t kernel_;
grad_kernel_sub_t grad_kernel_;
param_t param_;
grad_param_t grad_param_;
};
void TestNormalCase(const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& y_dims,
int axis) {
std::unique_ptr<ElementwiseAddGradTester> tester_add(
new ElementwiseAddGradTester(DDim(x_dims), DDim(y_dims), axis));
std::unique_ptr<ElementwiseSubGradTester> tester_sub(
new ElementwiseSubGradTester(DDim(x_dims), DDim(y_dims), axis));
tester_add->prepare_kernel();
tester_sub->prepare_kernel();
float delta = 0.001;
float max_grad_delta = 0.005;
tester_add->check_grad(delta, max_grad_delta);
tester_sub->check_grad(delta, max_grad_delta);
}
TEST(elementwise_grad_arm, compute) {
LOG(INFO) << "Test Elementwise grad";
DeviceInfo::Init();
TestNormalCase({3, 2}, {3, 2}, 0);
TestNormalCase({3, 5}, {3, 5}, 1);
TestNormalCase({3, 4, 3}, {3, 4, 3}, 0);
TestNormalCase({9, 2, 5}, {9, 2, 5}, 1);
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(elementwise_add_grad, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def);