Unverified commit fefc28ca, authored by mapingshuo and committed by GitHub

Mul grad (#3201)

* rm grad code

* add mul_grad, test=develop
Parent 13deb11e
......@@ -109,6 +109,7 @@ add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_de
if(LITE_WITH_TRAIN)
add_kernel(mean_grad_compute_arm ARM extra SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(activation_grad_compute_arm ARM basic SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(mul_grad_compute_arm ARM extra SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sgd_compute_arm ARM extra SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm)
endif()
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/mul_grad_compute.h"
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/backends/arm/math/sgemm.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void MulGradCompute::PrepareForRun() {
auto& ctx = this->ctx_->template As<ARMContext>();
}
void MulGradCompute::Run() {
// step1 flatten_2d
auto& param = Param<param_t>();
const auto x_dims = param.x->dims();
const auto y_dims = param.y->dims();
const auto out_dims = param.output_grad->dims();
m_ = static_cast<int>(x_dims.Slice(0, param.x_num_col_dims).production());
k_ = static_cast<int>(
x_dims.Slice(param.x_num_col_dims, x_dims.size()).production());
n_ = static_cast<int>(
y_dims.Slice(param.y_num_col_dims, y_dims.size()).production());
const auto* out_grad_data = param.output_grad->data<float>();
const auto* x_data = param.x->data<float>();
const auto* y_data = param.y->data<float>();
  float* x_grad_data = nullptr;
  float* y_grad_data = nullptr;
if (param.x_grad) {
x_grad_data = param.x_grad->mutable_data<float>();
}
if (param.y_grad) {
y_grad_data = param.y_grad->mutable_data<float>();
}
paddle::lite::operators::ActivationParam act_param;
act_param.has_active = false;
// out_grad * y^T = x_grad
// (m, n), (n, k) -> (m, k)
auto& ctx = this->ctx_->template As<ARMContext>();
if (param.x_grad) {
if (m_ == 1) {
lite::arm::math::sgemv(y_data,
out_grad_data,
x_grad_data,
false,
k_, // M
n_, // N
false,
nullptr,
false,
lite_api::ActivationType::kIndentity,
&ctx);
} else {
paddle::lite::arm::math::sgemm(false,
true, // is_transB,
m_, // M
k_, // N
n_, // K
1.0f, // alpha
out_grad_data, // A
n_, // lda
y_data, // B
n_, // ldb
0.f, // beta
x_grad_data, // C
k_, // ldc
NULL, // bias
false, // is_bias
act_param, // act_param
&ctx); // ctx
}
}
// x^T * out_grad = y_grad
// (k, m) (m, n) -> (k, n)
if (param.y_grad) {
if (n_ == 1) {
lite::arm::math::sgemv(x_data,
out_grad_data,
y_grad_data,
true,
k_, // M
m_, // N
false,
nullptr,
false,
lite_api::ActivationType::kIndentity,
&ctx);
} else {
paddle::lite::arm::math::sgemm(true, // is_transA
false, // is_transB,
k_, // M
n_, // N
m_, // K
1.0f, // alpha
x_data, // A
k_, // lda
out_grad_data, // B
n_, // ldb
0.f, // beta
y_grad_data, // C
n_, // ldc
NULL, // bias
false, // is_bias
act_param, // act_param
&ctx); // ctx
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(mul_grad,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::MulGradCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
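For reference, the kernel above computes the flattened gradients dX = dOut · Yᵀ and dY = Xᵀ · dOut, with X viewed as (m, k), Y as (k, n) and Out as (m, n). A minimal scalar sketch of that math (illustrative only, not part of this patch, and far slower than the sgemm/sgemv calls it mirrors):

// mul_grad_reference: naive triple loops over the flattened 2-D views.
// dx and dy may be null when the corresponding gradient is not requested,
// matching the optional X@GRAD / Y@GRAD outputs of the op.
void mul_grad_reference(const float* x, const float* y, const float* dout,
                        int m, int k, int n, float* dx, float* dy) {
  if (dx) {  // dX(m, k) = dOut(m, n) * Y^T(n, k)
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < k; ++j) {
        float acc = 0.f;
        for (int p = 0; p < n; ++p) acc += dout[i * n + p] * y[j * n + p];
        dx[i * k + j] = acc;
      }
    }
  }
  if (dy) {  // dY(k, n) = X^T(k, m) * dOut(m, n)
    for (int i = 0; i < k; ++i) {
      for (int j = 0; j < n; ++j) {
        float acc = 0.f;
        for (int p = 0; p < m; ++p) acc += x[p * k + i] * dout[p * n + j];
        dy[i * n + j] = acc;
      }
    }
  }
}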
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class MulGradCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::MulGradParam;
void PrepareForRun() override;
void Run() override;
virtual ~MulGradCompute() = default;
private:
int m_, n_, k_;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -144,6 +144,7 @@ add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS})
if (LITE_WITH_TRAIN)
add_operator(mean_grad_op extra SRCS mean_grad_op.cc DEPS ${op_DEPS})
add_operator(activation_grad_ops basic SRCS activation_grad_ops.cc DEPS ${op_DEPS})
add_operator(mul_grad_op basic SRCS mul_grad_op.cc DEPS ${op_DEPS})
add_operator(sgd_op extra SRCS sgd_op.cc DEPS ${op_DEPS})
endif()
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/mul_grad_op.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace operators {
bool MulGradOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.y);
CHECK_OR_FALSE(param_.output_grad);
CHECK_OR_FALSE(param_.x_grad || param_.y_grad);
CHECK_OR_FALSE(param_.x_num_col_dims);
CHECK_OR_FALSE(param_.y_num_col_dims);
const auto x_dims = param_.x->dims();
const auto y_dims = param_.y->dims();
const auto out_dims = param_.output_grad->dims();
CHECK_GT_OR_FALSE(x_dims.size(), static_cast<size_t>(param_.x_num_col_dims));
CHECK_GT_OR_FALSE(y_dims.size(), static_cast<size_t>(param_.y_num_col_dims));
auto x_flatten_dims = flatten_2d(x_dims, param_.x_num_col_dims);
auto y_flatten_dims = flatten_2d(y_dims, param_.y_num_col_dims);
auto out_flatten_dims = flatten_2d(out_dims, param_.x_num_col_dims);
// Out = X * Y;
CHECK_EQ_OR_FALSE(x_flatten_dims[1], y_flatten_dims[0]);
CHECK_EQ_OR_FALSE(x_flatten_dims[0], out_flatten_dims[0]);
CHECK_EQ_OR_FALSE(y_flatten_dims[1], out_flatten_dims[1]);
return true;
}
bool MulGradOpLite::InferShape() const {
const auto x_dims = param_.x->dims();
const auto y_dims = param_.y->dims();
if (param_.x_grad) {
param_.x_grad->Resize(x_dims);
param_.x_grad->set_lod(param_.x->lod());
}
if (param_.y_grad) {
param_.y_grad->Resize(y_dims);
param_.y_grad->set_lod(param_.y->lod());
}
  return true;
}
bool MulGradOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
CHECK(!op_desc.Input("X").empty());
CHECK(!op_desc.Input("Y").empty());
CHECK(!op_desc.Input("Out@GRAD").empty());
CHECK(!op_desc.Output("X@GRAD").empty() || !op_desc.Output("Y@GRAD").empty())
<< "at least one of 'X@GRAD' and 'Y@GRAD' is not empty";
auto *x_var = scope->FindVar(op_desc.Input("X").front());
CHECK(x_var);
param_.x = &x_var->Get<Tensor>();
auto *y_var = scope->FindVar(op_desc.Input("Y").front());
CHECK(y_var);
param_.y = &y_var->Get<Tensor>();
auto *out_grad_var = scope->FindVar(op_desc.Input("Out@GRAD").front());
CHECK(out_grad_var);
param_.output_grad = &out_grad_var->Get<Tensor>();
if (!op_desc.Output("X@GRAD").empty()) {
auto *x_grad_var = scope->FindVar(op_desc.Output("X@GRAD").front());
CHECK(x_grad_var);
param_.x_grad = x_grad_var->GetMutable<Tensor>();
}
if (!op_desc.Output("Y@GRAD").empty()) {
auto *y_grad_var = scope->FindVar(op_desc.Output("Y@GRAD").front());
CHECK(y_grad_var);
param_.y_grad = y_grad_var->GetMutable<Tensor>();
}
param_.x_num_col_dims = op_desc.GetAttr<int>("x_num_col_dims");
param_.y_num_col_dims = op_desc.GetAttr<int>("y_num_col_dims");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(mul_grad, paddle::lite::operators::MulGradOpLite);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class MulGradOpLite : public OpLite {
public:
MulGradOpLite() {}
explicit MulGradOpLite(const std::string &type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
std::string DebugString() const override { return "mul_grad"; }
private:
mutable MulGradParam param_;
};
inline std::vector<int64_t> flatten_2d(DDim dims, int num_col_dims) {
std::vector<int64_t> flatten_dims{1, 1};
for (int i = 0; i < dims.size(); i++) {
if (i < num_col_dims) {
flatten_dims[0] *= dims[i];
} else {
flatten_dims[1] *= dims[i];
}
}
return flatten_dims;
}
} // namespace operators
} // namespace lite
} // namespace paddle
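flatten_2d folds the leading num_col_dims axes into rows and the remaining axes into columns; this is the 2-D view from which the ARM kernel derives m_, k_ and n_. A hypothetical usage sketch (function name and includes are illustrative, not part of the patch):

#include <vector>
#include "lite/operators/mul_grad_op.h"  // brings in DDim and flatten_2d

void flatten_2d_example() {
  paddle::lite::DDim dims(std::vector<int64_t>({2, 3, 4, 5}));
  // Fold the first two axes into rows and the rest into columns:
  // {2, 3, 4, 5} with num_col_dims = 2  ->  {2 * 3, 4 * 5} == {6, 20}.
  auto rows_cols = paddle::lite::operators::flatten_2d(dims, 2);
}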
......@@ -66,28 +66,6 @@ class MulOpLite : public OpLite {
mutable MulParam param_;
};
#ifdef LITE_WITH_TRAIN
class MulGradOpLite : public OpLite {
public:
MulGradOpLite() {}
explicit MulGradOpLite(const std::string &type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
std::string DebugString() const override { return "mul_grad"; }
private:
mutable MulGradParam param_;
};
#endif
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -65,6 +65,7 @@ if(LITE_BUILD_EXTRA)
if (LITE_WITH_TRAIN)
lite_cc_test(test_kernel_mean_compute SRCS mean_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_activation_grad_compute SRCS activation_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_mul_grad_compute SRCS mul_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sgd_compute SRCS sgd_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
......
......@@ -109,6 +109,7 @@ void TestMul(const std::vector<int64_t>& x_dims,
int y_num_col_dims,
const Place& place,
float abs_error) {
LOG(INFO) << "run test arm";
std::unique_ptr<arena::TestCase> tester(new MulComputeTester(place,
"def",
DDim(x_dims),
......@@ -131,7 +132,6 @@ TEST(Mul, precision) {
#else
return;
#endif
TestMul({4, 5}, {5, 4}, 1, 1, place, abs_error);
TestMul({4, 5}, {5, 4, 3, 2}, 1, 1, place, abs_error);
TestMul({4, 20}, {5, 4, 3, 2}, 1, 2, place, abs_error);
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/mul_grad_compute.h"
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/mul_compute.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
using param_t = operators::MulParam;
using grad_param_t = operators::MulGradParam;
using kernel_t = MulCompute;
using grad_kernel_t = MulGradCompute;
class MulGradTester {
public:
explicit MulGradTester(const DDim& x_dims,
const DDim& y_dims,
int x_num_col_dims,
int y_num_col_dims)
: x_dims_(x_dims),
y_dims_(y_dims),
x_num_col_dims_(x_num_col_dims),
y_num_col_dims_(y_num_col_dims) {}
void prepare_kernel() {
std::unique_ptr<KernelContext> ctx1(new KernelContext);
ctx1->As<ARMContext>();
kernel_.SetContext(std::move(ctx1));
std::unique_ptr<KernelContext> ctx2(new KernelContext);
ctx2->As<ARMContext>();
delta_kernel_.SetContext(std::move(ctx2));
std::unique_ptr<KernelContext> ctx3(new KernelContext);
ctx3->As<ARMContext>();
grad_kernel_.SetContext(std::move(ctx3));
}
void run_forward(param_t* param,
kernel_t* kernel,
const std::vector<float>& x_vec,
const std::vector<float>& y_vec,
float* out_vec) {
Tensor x;
Tensor y;
Tensor output;
x.Resize(x_dims_);
y.Resize(y_dims_);
output.Resize(DDim(out_dims_));
auto* x_data = x.mutable_data<float>();
auto* y_data = y.mutable_data<float>();
for (int i = 0; i < x_dims_.production(); i++) {
x_data[i] = x_vec[i];
}
for (int i = 0; i < y_dims_.production(); i++) {
y_data[i] = y_vec[i];
}
param->x = &x;
param->y = &y;
param->output = &output;
param->x_num_col_dims = x_num_col_dims_;
param->y_num_col_dims = y_num_col_dims_;
kernel->SetParam(*param);
kernel->Launch();
auto* output_data = output.mutable_data<float>();
for (int i = 0; i < out_dims_.production(); i++) {
out_vec[i] = output_data[i];
}
}
void run_backward(grad_param_t* param,
grad_kernel_t* kernel,
const std::vector<float>& x_vec,
const std::vector<float>& y_vec,
const std::vector<float>& out_grad_vec,
float* x_grad_vec,
float* y_grad_vec) {
Tensor x;
Tensor x_grad;
Tensor y;
Tensor y_grad;
Tensor out_grad;
x.Resize(x_dims_);
x_grad.Resize(x_dims_);
y.Resize(y_dims_);
y_grad.Resize(y_dims_);
out_grad.Resize(out_dims_);
auto* x_data = x.mutable_data<float>();
auto* y_data = y.mutable_data<float>();
auto* out_grad_data = out_grad.mutable_data<float>();
for (int i = 0; i < x_dims_.production(); i++) {
x_data[i] = x_vec[i];
}
for (int i = 0; i < y_dims_.production(); i++) {
y_data[i] = y_vec[i];
}
for (int i = 0; i < out_dims_.production(); i++) {
out_grad_data[i] = out_grad_vec[i];
}
param->x = &x;
param->x_grad = &x_grad;
param->y = &y;
param->y_grad = &y_grad;
param->output_grad = &out_grad;
param->x_num_col_dims = x_num_col_dims_;
param->y_num_col_dims = y_num_col_dims_;
kernel->SetParam(*param);
kernel->Launch();
auto* x_grad_data = x_grad.mutable_data<float>();
auto* y_grad_data = y_grad.mutable_data<float>();
for (int i = 0; i < x_dims_.production(); i++) {
x_grad_vec[i] = x_grad_data[i];
}
for (int i = 0; i < y_dims_.production(); i++) {
y_grad_vec[i] = y_grad_data[i];
}
}
void check_grad() {
std::vector<int64_t> out_shape;
for (int i = 0; i < x_num_col_dims_; i++) {
out_shape.push_back(x_dims_[i]);
}
for (int i = y_num_col_dims_; i < y_dims_.size(); i++) {
out_shape.push_back(y_dims_[i]);
}
out_dims_ = DDim(out_shape);
// forward
std::vector<float> x(x_dims_.production());
std::vector<float> y(y_dims_.production());
std::vector<float> out(out_dims_.production());
fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production());
fill_data_rand(y.data(), -1.f, 1.f, y_dims_.production());
this->run_forward(&param_, &kernel_, x, y, out.data());
for (int i = 0; i < x_dims_.production(); i++) {
LOG(INFO) << "x_" << i << ": " << x[i];
}
for (int i = 0; i < y_dims_.production(); i++) {
LOG(INFO) << "y_" << i << ": " << y[i];
}
for (int i = 0; i < out_dims_.production(); i++) {
LOG(INFO) << "out_" << i << ": " << out[i];
}
// backward
std::vector<float> out_grad(out_dims_.production());
std::vector<float> x_grad(x_dims_.production());
std::vector<float> y_grad(y_dims_.production());
for (int i = 0; i < out_dims_.production(); i++) {
out_grad[i] = 1.0;
}
this->run_backward(&grad_param_,
&grad_kernel_,
x,
y,
out_grad,
x_grad.data(),
y_grad.data());
// get numeric gradient
std::vector<float> x_delta(x_dims_.production());
std::vector<float> y_delta(y_dims_.production());
std::vector<float> out_delta(out_dims_.production());
float delta = 0.001;
float max_grad_delta = 0.005;
for (int i = 0; i < x_dims_.production(); i++) {
LOG(INFO) << "--------------------";
LOG(INFO) << "delta: " << delta;
LOG(INFO) << "max_grad_delta: " << max_grad_delta;
for (int j = 0; j < x_dims_.production(); j++) {
// x_delta[j] = i == j ? x[j] + delta : x[j];
if (i == j) {
x_delta[j] = x[j] + delta;
} else {
x_delta[j] = x[j];
}
}
this->run_forward(
&delta_param_, &delta_kernel_, x_delta, y, out_delta.data());
for (int j = 0; j < x_dims_.production(); j++) {
LOG(INFO) << "x_" << j << ": " << x[j];
LOG(INFO) << "x_delta_" << j << ": " << x_delta[j];
}
for (int j = 0; j < y_dims_.production(); j++) {
LOG(INFO) << "y_" << j << ": " << y[j];
}
for (int j = 0; j < out_dims_.production(); j++) {
LOG(INFO) << "out_delta_" << j << ": " << out_delta[j];
}
float sum = 0;
for (int j = 0; j < out_dims_.production(); j++) {
sum += (out_delta[j] - out[j]);
}
LOG(INFO) << "x_grad_" << i << ": " << x_grad[i];
LOG(INFO) << "x_grad_num_" << i << ": " << sum / delta;
EXPECT_NEAR(x_grad[i], sum / delta, max_grad_delta);
}
for (int i = 0; i < y_dims_.production(); i++) {
for (int j = 0; j < y_dims_.production(); j++) {
y_delta[j] = i == j ? y[j] + delta : y[j];
}
this->run_forward(
&delta_param_, &delta_kernel_, x, y_delta, out_delta.data());
float sum = 0;
for (int j = 0; j < out_dims_.production(); j++) {
sum += out_delta[j] - out[j];
}
LOG(INFO) << "y_grad_" << i << ": " << y_grad[i];
LOG(INFO) << "y_grad_num_" << i << ": " << sum / delta;
EXPECT_NEAR(y_grad[i], sum / delta, max_grad_delta);
}
}
private:
DDim x_dims_;
DDim y_dims_;
DDim out_dims_;
int x_num_col_dims_;
int y_num_col_dims_;
kernel_t kernel_;
kernel_t delta_kernel_;
grad_kernel_t grad_kernel_;
param_t param_;
param_t delta_param_;
grad_param_t grad_param_;
};
void TestNormalCase(const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& y_dims,
int x_num_col_dims,
int y_num_col_dims) {
std::unique_ptr<MulGradTester> tester(new MulGradTester(
DDim(x_dims), DDim(y_dims), x_num_col_dims, y_num_col_dims));
tester->prepare_kernel();
tester->check_grad();
}
TEST(mul_grad_arm, compute) {
LOG(INFO) << "Test Mul grad";
DeviceInfo::Init();
TestNormalCase({1, 3}, {3, 2}, 1, 1);
TestNormalCase({3, 2}, {2, 1}, 1, 1);
TestNormalCase({3, 1}, {1, 7}, 1, 1);
TestNormalCase({2, 3}, {3, 2}, 1, 1);
TestNormalCase({4, 5}, {5, 4}, 1, 1);
TestNormalCase({4, 5}, {5, 4, 3, 2}, 1, 1);
TestNormalCase({3, 4}, {2, 2, 3}, 1, 2);
TestNormalCase({4, 20}, {5, 4, 3, 2}, 1, 2);
TestNormalCase({4, 60}, {5, 4, 3, 2}, 1, 3);
TestNormalCase({2, 3, 4, 5}, {60, 4}, 1, 1);
TestNormalCase({2, 3, 4, 5}, {20, 4}, 2, 1);
TestNormalCase({2, 3, 4, 5}, {5, 4}, 3, 1);
TestNormalCase({2, 3, 4, 5}, {60, 3, 4, 5}, 1, 1);
TestNormalCase({2, 3, 4, 5}, {4, 5, 6, 2}, 2, 2);
TestNormalCase({2, 3, 4, 5}, {5, 1, 4, 2}, 3, 2);
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul_grad, kARM, kFloat, kNCHW, def);
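The test validates each analytic gradient entry against a finite difference: with dL/dOut set to all ones, x_grad[i] should be close to (Σ_j Out(x + δ·e_i)_j − Σ_j Out(x)_j) / δ. A generic sketch of that rule, decoupled from the kernels above (function and parameter names are illustrative, not part of the patch):

#include <functional>
#include <vector>

// numeric_grad: central idea of the check in MulGradTester::check_grad().
// f maps a flat input vector to a flat output vector; the all-ones out_grad
// turns the loss into the plain sum of the outputs.
std::vector<float> numeric_grad(
    const std::function<std::vector<float>(const std::vector<float>&)>& f,
    std::vector<float> x, float delta = 1e-3f) {
  const std::vector<float> base = f(x);
  std::vector<float> grad(x.size(), 0.f);
  for (size_t i = 0; i < x.size(); ++i) {
    const float saved = x[i];
    x[i] = saved + delta;  // perturb one coordinate
    const std::vector<float> shifted = f(x);
    x[i] = saved;  // restore
    float diff_sum = 0.f;
    for (size_t j = 0; j < base.size(); ++j) diff_sum += shifted[j] - base[j];
    grad[i] = diff_sum / delta;  // matches the all-ones out_grad convention
  }
  return grad;
}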