Unverified · Commit a0c473f4 authored by xiaoguoguo626807 and committed by GitHub

【prim】Matmul double grad composite api (#50452)

* modify name

* merge develop

* original code

* build modify

* success 2*2

* fused dim=1 failed

* success

* modify static

* success for static except dim=1

* delete log

* tmp modify

* success

* success

* add fp1664

* delete fp16 cpu test

* stop windows test

* review modify

* modify tanh test

* modify tanh

* fix_conflict

* modify static prim

* fix_conflict

* Update test_static_prim.cc

* update

* bug fix
Parent 74446b37
......@@ -18,6 +18,7 @@
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
......@@ -246,6 +247,51 @@ class MatMulV2OpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
}
};
class MatMulCompositeDoubleGradOpMaker : public prim::CompositeGradOpMakerBase {
public:
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
void Apply() override {
// get inputs
paddle::experimental::Tensor x = this->GetSingleForwardInput("X");
paddle::experimental::Tensor y = this->GetSingleForwardInput("Y");
paddle::experimental::Tensor dout =
this->GetSingleForwardInput(framework::GradVarName("Out"));
paddle::optional<paddle::experimental::Tensor> ddx =
this->GetOptionalSingleOutputGrad(framework::GradVarName("X"));
paddle::optional<paddle::experimental::Tensor> ddy =
this->GetOptionalSingleOutputGrad(framework::GradVarName("Y"));
// get attr
bool trans_x = this->Attr<bool>("trans_x");
bool trans_y = this->Attr<bool>("trans_y");
// get output
paddle::experimental::Tensor x_grad_t = this->GetSingleInputGrad("X");
paddle::experimental::Tensor y_grad_t = this->GetSingleInputGrad("Y");
paddle::experimental::Tensor grad_out_grad_t =
this->GetSingleInputGrad(framework::GradVarName("Out"));
// get output ptr
paddle::experimental::Tensor* x_grad = this->GetOutputPtr(&x_grad_t);
paddle::experimental::Tensor* y_grad = this->GetOutputPtr(&y_grad_t);
paddle::experimental::Tensor* grad_out_grad =
this->GetOutputPtr(&grad_out_grad_t);
// get output original name
std::string x_grad_name = this->GetOutputName(x_grad_t);
std::string y_grad_name = this->GetOutputName(y_grad_t);
std::string grad_out_grad_name = this->GetOutputName(grad_out_grad_t);
VLOG(3) << "Runing matmul_double_grad composite func";
// call composite backward func
prim::matmul_double_grad<prim::DescTensor>(
x, y, dout, ddx, ddy, trans_x, trans_y, x_grad, y_grad, grad_out_grad);
// recover output name
this->RecoverOutputName(x_grad_t, x_grad_name);
this->RecoverOutputName(y_grad_t, y_grad_name);
this->RecoverOutputName(grad_out_grad_t, grad_out_grad_name);
}
};
class MatMulV2OpTripleGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -335,6 +381,7 @@ REGISTER_OPERATOR(matmul_v2_grad,
ops::MatMulV2OpGrad,
ops::MatMulV2OpDoubleGradMaker<paddle::framework::OpDesc>,
ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>,
ops::MatMulCompositeDoubleGradOpMaker,
MatMulV2GradInferShapeFunctor);
REGISTER_OPERATOR(matmul_v2_grad_grad,
......
......@@ -13,10 +13,11 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/prim/api/all.h"
#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
namespace prim {
using Tensor = paddle::experimental::Tensor;
......@@ -68,7 +69,7 @@ void gather_grad(const Tensor& x,
template <typename T>
void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
if (!grad_x) return;
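// d(tanh(x))/dx = 1 - tanh(x)^2 = 1 - out * out; computing out * out instead of
// out.pow(2.0) avoids emitting an extra elementwise_pow op (see the updated tests below).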
auto grad_x_tmp = grad_out * (1.0 - out.pow(2.0));
auto grad_x_tmp = grad_out * (1 - out * out);
set_output<T>(grad_x_tmp, grad_x);
}
......@@ -210,6 +211,11 @@ void sum_grad(const Tensor& x,
}
} else {
axis_ = axis.GetData();
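// normalize negative axes, e.g. axis = [-2, -1] with x_dim_size = 3 becomes [1, 2]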
for (int64_t i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
axis_[i] = axis[i] + x_dim_size;
}
}
}
auto out_grad_ = unsqueeze<T>(out_grad, axis_);
x_grad_tmp = out_grad_.expand(IntArray(x_dim));
......@@ -353,6 +359,315 @@ void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
}
}
template <typename T>
void matmul_double_grad(const Tensor& x,
const Tensor& y,
const Tensor& grad_out,
const paddle::optional<Tensor>& grad_x_grad,
const paddle::optional<Tensor>& grad_y_grad,
bool transpose_x,
bool transpose_y,
Tensor* x_grad,
Tensor* y_grad,
Tensor* grad_out_grad) {
// Get dims from the input x, y, output_grad
std::vector<std::int64_t> x_dims = vectorize(x.dims());
std::vector<std::int64_t> y_dims = vectorize(y.dims());
std::vector<std::int64_t> grad_out_dims = vectorize(grad_out.dims());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int dout_ndim = grad_out_dims.size();
// prepare dims for x_ndim <= 1 || y_ndim <= 1
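// illustrative example: x_dims = [2] and y_dims = [2, 3] give x_help = [1, 2] and
// out_help = [1, 3]; the inserted 1 is stripped again at the end of this function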
Tensor x_help, y_help, xg_help, yg_help, out_help;
if (x_ndim == 1 && y_ndim == 1) {
transpose_x = false;
transpose_y = false;
x_help = reshape<T>(x, IntArray(std::vector<int64_t>({1, x_dims[0]})));
y_help = reshape<T>(y, IntArray(std::vector<int64_t>({y_dims[0], 1})));
if (grad_x_grad) {
xg_help = reshape<T>(grad_x_grad.get(),
IntArray(std::vector<int64_t>({1, x_dims[0]})));
}
if (grad_y_grad) {
yg_help = reshape<T>(grad_y_grad.get(),
IntArray(std::vector<int64_t>({y_dims[0], 1})));
}
out_help = reshape<T>(grad_out, IntArray(std::vector<int64_t>({1, 1})));
} else if (x_ndim == 1) {
transpose_x = false;
x_help = reshape<T>(x, IntArray(std::vector<int64_t>({1, x_dims[0]})));
y_help = y;
if (grad_x_grad) {
xg_help = reshape<T>(grad_x_grad.get(),
IntArray(std::vector<int64_t>({1, x_dims[0]})));
}
if (grad_y_grad) {
yg_help = grad_y_grad.get();
}
auto tmp_grad_out_dims = grad_out_dims;
tmp_grad_out_dims.insert(tmp_grad_out_dims.begin(), 1);
out_help = reshape<T>(grad_out, IntArray(tmp_grad_out_dims));
} else if (y_ndim == 1) {
transpose_y = false;
x_help = x;
y_help = reshape<T>(y, IntArray(std::vector<int64_t>({y_dims[0], 1})));
if (grad_x_grad) {
xg_help = grad_x_grad.get();
}
if (grad_y_grad) {
yg_help = reshape<T>(grad_y_grad.get(),
IntArray(std::vector<int64_t>({y_dims[0], 1})));
}
auto tmp_grad_out_dims = grad_out_dims;
tmp_grad_out_dims.push_back(1);
out_help = reshape<T>(grad_out, IntArray(tmp_grad_out_dims));
} else {
x_help = x;
y_help = y;
if (grad_x_grad) {
xg_help = grad_x_grad.get();
}
if (grad_y_grad) {
yg_help = grad_y_grad.get();
}
out_help = grad_out;
}
bool is_broadcast = true;
if (x_ndim <= 2 && y_ndim <= 2) {
is_broadcast = false;
} else if (x_ndim != y_ndim) {
is_broadcast = true;
} else {
is_broadcast = !std::equal(
x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin());
}
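// e.g. x_dims = [2, 1, 5, 2] and y_dims = [1, 3, 2, 4]: the ranks match but the
// leading batch dims (2, 1) and (1, 3) differ, so is_broadcast stays true; for
// 2-D x and y it is always false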
Tensor dx, dy, ddout_1, ddout_2, ddout;
if (!grad_x_grad && !grad_y_grad) {
x_grad = nullptr;
y_grad = nullptr;
grad_out_grad = nullptr;
return;
} else if (!grad_x_grad) {
y_grad = nullptr;
if (!transpose_x && !transpose_y) {
if (x_grad) {
dx = matmul<T>(out_help, yg_help, false, true);
}
if (grad_out_grad) {
ddout = matmul<T>(x_help, yg_help, false, false);
}
} else if (!transpose_x && transpose_y) {
if (x_grad) {
dx = matmul<T>(out_help, yg_help, false, false);
}
if (grad_out_grad) {
ddout = matmul<T>(x_help, yg_help, false, true);
}
} else if (transpose_x && !transpose_y) {
if (x_grad) {
dx = matmul<T>(yg_help, out_help, false, true);
}
if (grad_out_grad) {
ddout = matmul<T>(x_help, yg_help, true, false);
}
} else {
if (x_grad) {
dx = matmul<T>(yg_help, out_help, true, true);
}
if (grad_out_grad) {
ddout = matmul<T>(x_help, yg_help, true, true);
}
}
} else if (!grad_y_grad) {
x_grad = nullptr;
if (!transpose_x && !transpose_y) {
if (y_grad) {
dy = matmul<T>(xg_help, out_help, true, false);
}
if (grad_out_grad) {
ddout = matmul<T>(xg_help, y_help, false, false);
}
} else if (!transpose_x && transpose_y) {
if (y_grad) {
dy = matmul<T>(out_help, xg_help, true, false);
}
if (grad_out_grad) {
ddout = matmul<T>(xg_help, y_help, false, true);
}
} else if (transpose_x && !transpose_y) {
if (y_grad) {
dy = matmul<T>(xg_help, out_help, false, false);
}
if (grad_out_grad) {
ddout = matmul<T>(xg_help, y_help, true, false);
}
} else {
if (y_grad) {
dy = matmul<T>(out_help, xg_help, true, true);
}
if (grad_out_grad) {
ddout = matmul<T>(xg_help, y_help, true, true);
}
}
} else {
if (!transpose_x && !transpose_y) {
if (x_grad) {
dx = matmul<T>(out_help, yg_help, false, true);
}
if (y_grad) {
dy = matmul<T>(xg_help, out_help, true, false);
}
if (grad_out_grad) {
ddout_1 = matmul<T>(x_help, yg_help, false, false);
ddout_2 = matmul<T>(xg_help, y_help, false, false);
ddout = add<T>(ddout_1, ddout_2);
}
} else if (!transpose_x && transpose_y) {
if (x_grad) {
dx = matmul<T>(out_help, yg_help, false, false);
}
if (y_grad) {
dy = matmul<T>(out_help, xg_help, true, false);
}
if (grad_out_grad) {
ddout_1 = matmul<T>(x_help, yg_help, false, true);
ddout_2 = matmul<T>(xg_help, y_help, false, true);
ddout = add<T>(ddout_1, ddout_2);
}
} else if (transpose_x && !transpose_y) {
if (x_grad) {
dx = matmul<T>(yg_help, out_help, false, true);
}
if (y_grad) {
dy = matmul<T>(xg_help, out_help, false, false);
}
if (grad_out_grad) {
ddout_1 = matmul<T>(x_help, yg_help, true, false);
ddout_2 = matmul<T>(xg_help, y_help, true, false);
ddout = add<T>(ddout_1, ddout_2);
}
} else {
if (x_grad) {
dx = matmul<T>(yg_help, out_help, true, true);
}
if (y_grad) {
dy = matmul<T>(out_help, xg_help, true, true);
}
if (grad_out_grad) {
ddout_1 = matmul<T>(x_help, yg_help, true, true);
ddout_2 = matmul<T>(xg_help, y_help, true, true);
ddout = add<T>(ddout_1, ddout_2);
}
}
}
if (is_broadcast) {
// Case3: broadcast. Reducing the sum over the broadcast dims costs much time
// and wastes memory, so this case should be avoided in practice.
VLOG(3) << "Reducing the sum over the broadcast dims costs much time and "
"wastes memory, so this case should be avoided in practice";
// Reduce sum to get grad by ReduceSum
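// Illustrative example: for x_dims = [2, 1, 5, 2] and y_dims = [1, 3, 2, 4] (no
// transpose), grad_out has dims [2, 3, 5, 4] and dx is computed as [2, 3, 5, 2];
// get_reduce_dims returns {1}, so dx is summed over axis 1 with keepdim and then
// reshaped back to [2, 1, 5, 2].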
if (x_grad) {
auto tx_dims = x_dims;
auto tx_ndim = x_ndim;
auto tdout_ndim = dout_ndim;
if (x_ndim == 1) {
tx_dims = std::vector<int64_t>({1, x_dims[0]});
tx_ndim = x_ndim + 1;
tdout_ndim = dout_ndim + 1;
}
auto x_grad_reduce_dims =
get_reduce_dims(dx, tdout_ndim, tx_ndim, &tx_dims);
if (!x_grad_reduce_dims.empty()) {
dx = sum<T>(dx, IntArray(x_grad_reduce_dims), dx.dtype(), true);
}
dx = reshape<T>(dx, IntArray(tx_dims));
}
if (y_grad) {
auto ty_dims = y_dims;
auto ty_ndim = y_ndim;
auto tdout_ndim = dout_ndim;
if (y_ndim == 1) {
ty_dims = std::vector<int64_t>({y_dims[0], 1});
ty_ndim = y_ndim + 1;
tdout_ndim = dout_ndim + 1;
}
auto y_grad_reduce_dims =
get_reduce_dims(dy, tdout_ndim, ty_ndim, &ty_dims);
if (!y_grad_reduce_dims.empty()) {
dy = sum<T>(dy, IntArray(y_grad_reduce_dims), dy.dtype(), true);
}
dy = reshape<T>(dy, IntArray(ty_dims));
}
}
// recover the original dims of the outputs (remove the inserted 1)
std::vector<int64_t> dx_dims =
dx.initialized() ? vectorize(dx.dims()) : std::vector<int64_t>({});
std::vector<int64_t> dy_dims =
dy.initialized() ? vectorize(dy.dims()) : std::vector<int64_t>({});
std::vector<int64_t> ddout_dims =
ddout.initialized() ? vectorize(ddout.dims()) : std::vector<int64_t>({});
if (x_ndim == 1 && y_ndim == 1) {
if (dx.initialized() && dx_dims[0] == 1) {
dx = reshape<T>(dx, IntArray(x_dims));
}
if (dy.initialized() && dy_dims.back() == 1) {
dy = reshape<T>(dy, IntArray(y_dims));
}
if (ddout.initialized() && ddout_dims == std::vector<int64_t>({1, 1})) {
ddout = reshape<T>(ddout, IntArray(std::vector<int64_t>({1})));
}
} else if (x_ndim == 1) {
if (dx.initialized() && dx_dims[0] == 1) {
dx = reshape<T>(dx, IntArray(x_dims));
}
if (ddout.initialized() && ddout_dims[0] == 1) {
ddout = reshape<T>(ddout,
IntArray(std::vector<int64_t>(
{ddout_dims.cbegin() + 1, ddout_dims.cend()})));
}
} else if (y_ndim == 1) {
if (dy.initialized() && dy_dims.back() == 1) {
dy = reshape<T>(dy, IntArray(y_dims));
}
if (ddout.initialized() && ddout_dims.back() == 1) {
ddout = reshape<T>(ddout,
IntArray(std::vector<int64_t>(
{ddout_dims.cbegin(),
ddout_dims.cbegin() + ddout_dims.size() - 1})));
}
}
if (x_grad) {
set_output<T>(dx, x_grad);
}
if (y_grad) {
set_output<T>(dy, y_grad);
}
if (grad_out_grad) {
set_output<T>(ddout, grad_out_grad);
}
}
template <typename T>
void slice_grad(const Tensor& input,
const Tensor& out_grad,
......
......@@ -17,30 +17,33 @@
#include <vector>
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle {
namespace prim {
// We put some utility APIs here
template <typename T>
paddle::experimental::Tensor empty(const paddle::experimental::IntArray& shape,
paddle::experimental::DataType dype,
const paddle::Place& place);
Tensor empty(const paddle::experimental::IntArray& shape,
paddle::experimental::DataType dype,
const paddle::Place& place);
template <typename T>
paddle::experimental::Tensor empty_like(const paddle::experimental::Tensor& x,
paddle::experimental::DataType dtype,
const paddle::Place& place);
Tensor empty_like(const Tensor& x,
paddle::experimental::DataType dtype,
const paddle::Place& place);
// copy tensor to the output ptr; in static mode the assign op is needed
template <typename T>
void by_pass(const paddle::experimental::Tensor& x,
paddle::experimental::Tensor* out);
void by_pass(const Tensor& x, Tensor* out);
// set the output ptr impl with the tmp ptr impl; in dygraph, OutGradMeta should be set
template <typename T>
void set_output(const paddle::experimental::Tensor& x_tmp,
paddle::experimental::Tensor* x);
void set_output(const Tensor& x_tmp, Tensor* x);
// These methods don't need to be specified
static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims,
......@@ -78,11 +81,38 @@ static phi::DDim get_reduce_dims(const phi::DDim& x_dims,
return get_reduce_dims_from_out(out_dims, x_dims);
}
static std::vector<int> get_reduce_dims(const Tensor& dx,
const int& dout_ndim,
const int& x_ndim,
std::vector<int64_t>* x_dims) {
// This branch handles broadcast with a 1-D operand: the 1-D tensor is promoted
// to 2-D, which makes ddout_ndim > dout_ndim, but ddout_ndim can only be used
// when grad_out_grad != nullptr.
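// Illustrative example: dx.dims() = [2, 3, 5, 2], dout_ndim = 4, x_ndim = 4 and
// *x_dims = [2, 1, 5, 2] give broadcast_dims = [2, 1, 5, 2] and a result of {1},
// i.e. dx has to be summed over axis 1 to match x.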
if (dout_ndim < x_ndim) {
return std::vector<int>({});
}
const std::vector<std::int64_t> dx_dims = phi::vectorize(dx.dims());
std::vector<std::int64_t> broadcast_dims(dout_ndim);
std::fill(
broadcast_dims.data(), broadcast_dims.data() + dout_ndim - x_ndim, 1);
std::copy(x_dims->data(),
x_dims->data() + x_ndim,
broadcast_dims.data() + dout_ndim - x_ndim);
std::vector<int> reduce_dims;
for (int i = 0; i <= dout_ndim - 3; i++) {
if (dx_dims[i] != 1 && broadcast_dims[i] == 1) {
reduce_dims.push_back(i);
}
}
return reduce_dims;
}
// TODO(cxxly): Check and throws InvalidCastException when overflow.
template <typename SRC_T, typename DST_T>
static std::vector<DST_T> unsafe_vector_cast(const std::vector<SRC_T>& src) {
std::vector<DST_T> dst(src.begin(), src.end());
return dst;
}
} // namespace prim
} // namespace paddle
......@@ -194,7 +194,7 @@ TEST(StaticPrim, TanhBackwardComposite) {
target_block,
grad_sub_block));
ASSERT_EQ(target_block->AllOps().size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops.size(), static_cast<std::size_t>(5));
ASSERT_EQ(grad_ops.size(), static_cast<std::size_t>(4));
ASSERT_EQ(target_block->AllOps()[0]->Type(), "tanh");
ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X").size(),
static_cast<std::size_t>(1));
......@@ -204,41 +204,34 @@ TEST(StaticPrim, TanhBackwardComposite) {
ASSERT_EQ(target_block->AllOps()[0]->Outputs().at("Out")[0], "b");
ASSERT_EQ(target_block->AllOps()[0]->Outputs().at("Out")[0], "b");
ASSERT_EQ(grad_ops[0]->Type(), "fill_constant");
ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[0]->GetAttr("dtype")),
static_cast<int>(5)); // ProtoDataType::FP32
ASSERT_EQ(grad_ops[0]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[0]->Type(), "elementwise_mul");
ASSERT_EQ(grad_ops[0]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[0]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[0]->Inputs().at("Y")[0], "b");
ASSERT_EQ(grad_ops[0]->Inputs().at("X")[0], "b");
ASSERT_EQ(grad_ops[1]->Type(), "elementwise_pow");
ASSERT_EQ(grad_ops[1]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[1]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[1]->Inputs().at("X")[0], "b");
ASSERT_EQ(grad_ops[0]->Outputs().at("Out").size(),
ASSERT_EQ(grad_ops[1]->Type(), "fill_constant");
ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[1]->GetAttr("dtype")),
static_cast<int>(5)); // ProtoDataType::FP32
ASSERT_EQ(grad_ops[1]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[2]->Type(), "fill_constant");
ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[2]->GetAttr("dtype")),
static_cast<int>(5)); // ProtoDataType::FP32
ASSERT_EQ(grad_ops[2]->Type(), "elementwise_sub");
ASSERT_EQ(grad_ops[2]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[2]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[2]->Inputs().at("X")[0],
grad_ops[1]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[2]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[3]->Type(), "elementwise_sub");
ASSERT_EQ(grad_ops[3]->Type(), "elementwise_mul");
ASSERT_EQ(grad_ops[3]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[3]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[3]->Inputs().at("X")[0],
ASSERT_EQ(grad_ops[3]->Inputs().at("Y")[0],
grad_ops[2]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[3]->Inputs().at("X")[0], "b@GRAD");
ASSERT_EQ(grad_ops[3]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[4]->Type(), "elementwise_mul");
ASSERT_EQ(grad_ops[4]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[4]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[4]->Inputs().at("Y")[0],
grad_ops[3]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[4]->Inputs().at("X")[0], "b@GRAD");
ASSERT_EQ(grad_ops[4]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
}
TEST(StaticCompositeGradMaker, TestMutiInputMethod) {
......
......@@ -47,6 +47,8 @@ class DescTensor : public phi::ExtendedTensor,
const phi::Place& place() const override { return place_; }
bool initialized() const override { return desc_ptr_ != nullptr; }
// TODO(jiabin): override more operators here.
private:
......
......@@ -28,6 +28,7 @@
namespace paddle {
namespace prim {
class UniqueNameGenerator {
public:
explicit UniqueNameGenerator(std::string prefix = "") : prefix_(prefix) {}
......@@ -94,7 +95,7 @@ class StaticCompositeContext {
: current_block_desc_(nullptr),
generator_(new UniqueNameGenerator()),
skip_comp_ops_({"matmul_v2"}) {}
// TODO(Ruting): add test cases when static backward is fixed
framework::BlockDesc* current_block_desc_;
std::unique_ptr<UniqueNameGenerator> generator_;
std::unordered_set<std::string> skip_comp_ops_;
......
......@@ -722,6 +722,7 @@
param : [x, y, grad_out]
kernel :
func : matmul_double_grad
composite : matmul_double_grad(x, y, grad_out, grad_x_grad, grad_y_grad, transpose_x=false, transpose_y=false)
backward : matmul_triple_grad
optional : grad_x_grad, grad_y_grad
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core
core._set_prim_backward_enabled(True)
# vector * vector out.shape = (1)
# matrix * vector out.shape = (2)
# vector * matrix out.shape = (3)
# batched matrix * batched matrix 4 for trans out.shape = (2, 3, 5)
# batched matrix * broadcasted vector out.shape = (2, 3)
# batched matrix * broadcasted matrix out.shape = (2, 3, 5, 4)
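# e.g. (illustrative, numpy broadcasting semantics match paddle.matmul here):
# np.matmul(np.ones((2, 3, 4)), np.ones((4,))).shape == (2, 3)
# np.matmul(np.ones((2, 1, 5, 2)), np.ones((1, 3, 2, 4))).shape == (2, 3, 5, 4)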
TOLERANCE = {
"float16": {"rtol": 1e-3, "atol": 1e-3},
"float32": {"rtol": 1e-6, "atol": 1e-6},
"float64": {"rtol": 1e-15, "atol": 1e-15},
}
@param.parameterized_class(
('primal0', 'primal1', 'trans_0', 'trans_1', 'dtype'),
[
(
np.random.rand(2),
np.random.rand(2),
False,
False,
np.float32,
),
(
np.random.rand(2, 3),
np.random.rand(3),
False,
False,
np.float32,
),
(
np.random.rand(2),
np.random.rand(2, 3),
False,
False,
np.float32,
),
(
np.random.rand(2),
np.random.rand(3, 2),
False,
True,
np.float32,
),
(
np.random.rand(2, 3, 4),
np.random.rand(2, 4, 5),
False,
False,
np.float32,
),
(
np.random.rand(2, 4, 3),
np.random.rand(2, 4, 5),
True,
False,
np.float32,
),
(
np.random.rand(2, 3, 4),
np.random.rand(2, 5, 4),
False,
True,
np.float32,
),
(
np.random.rand(2, 4, 3),
np.random.rand(2, 5, 4),
True,
True,
np.float32,
),
(
np.random.rand(2, 3, 4),
np.random.rand(4),
False,
False,
np.float32,
),
(
np.random.rand(2, 1, 5, 2),
np.random.rand(1, 3, 2, 4),
False,
False,
np.float32,
),
(
np.random.rand(2),
np.random.rand(2),
False,
False,
np.float16,
),
(
np.random.rand(2, 3),
np.random.rand(3),
False,
False,
np.float16,
),
(
np.random.rand(2),
np.random.rand(2, 3),
False,
False,
np.float16,
),
(
np.random.rand(2),
np.random.rand(3, 2),
False,
True,
np.float16,
),
(
np.random.rand(2, 3, 4),
np.random.rand(2, 4, 5),
False,
False,
np.float16,
),
(
np.random.rand(2, 4, 3),
np.random.rand(2, 4, 5),
True,
False,
np.float16,
),
(
np.random.rand(2, 3, 4),
np.random.rand(2, 5, 4),
False,
True,
np.float16,
),
(
np.random.rand(2, 4, 3),
np.random.rand(2, 5, 4),
True,
True,
np.float16,
),
(
np.random.rand(2, 3, 4),
np.random.rand(4),
False,
False,
np.float16,
),
(
np.random.rand(2, 1, 5, 2),
np.random.rand(1, 3, 2, 4),
False,
False,
np.float16,
),
(
np.random.rand(2),
np.random.rand(2),
False,
False,
np.float64,
),
(
np.random.rand(2, 3),
np.random.rand(3),
False,
False,
np.float64,
),
(
np.random.rand(2),
np.random.rand(2, 3),
False,
False,
np.float64,
),
(
np.random.rand(2),
np.random.rand(3, 2),
False,
True,
np.float64,
),
(
np.random.rand(2, 3, 4),
np.random.rand(2, 5, 4),
False,
True,
np.float64,
),
(
np.random.rand(2, 3, 4),
np.random.rand(2, 4, 5),
False,
False,
np.float64,
),
(
np.random.rand(2, 4, 3),
np.random.rand(2, 4, 5),
True,
False,
np.float64,
),
(
np.random.rand(2, 4, 3),
np.random.rand(2, 5, 4),
True,
True,
np.float64,
),
(
np.random.rand(2, 3, 4),
np.random.rand(4),
False,
False,
np.float64,
),
(
np.random.rand(2, 1, 5, 2),
np.random.rand(1, 3, 2, 4),
False,
False,
np.float64,
),
],
)
class TestMatmulDoubleGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.dtype)
cls.primal1 = cls.primal1.astype(cls.dtype)
cls.trans_0 = cls.trans_0
cls.trans_1 = cls.trans_1
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_matmul_grad_comp(self):
def actual(primal0, primal1, trans_0, trans_1, dtype_):
core._set_prim_backward_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype=dtype_, stop_gradient=False)
y = paddle.to_tensor(primal1, dtype=dtype_, stop_gradient=False)
out = paddle.matmul(x, y, trans_0, trans_1)
dout = paddle.ones_like(out, dtype=dtype_)
dout.stop_gradient = False
res = paddle.grad(
[out], [x, y], dout, create_graph=True, retain_graph=True
)
res_double = paddle.grad(
res, [x, y, dout], create_graph=True, retain_graph=True
)
return (
res_double[0].numpy(),
res_double[1].numpy(),
res_double[2].numpy(),
)
def desired(primal0, primal1, trans_0, trans_1, dtype_):
core._set_prim_backward_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype=dtype_, stop_gradient=False)
y = paddle.to_tensor(primal1, dtype=dtype_, stop_gradient=False)
out = paddle.matmul(x, y, trans_0, trans_1)
dout = paddle.ones_like(out, dtype=dtype_)
dout.stop_gradient = False
res = paddle.grad(
out, [x, y], dout, create_graph=True, retain_graph=True
)
res_double = paddle.grad(
res, [x, y, dout], create_graph=True, retain_graph=True
)
return (
res_double[0].numpy(),
res_double[1].numpy(),
res_double[2].numpy(),
)
d_type = "float32"
if self.primal0.dtype == np.float16:
d_type = "float16"
elif self.primal0.dtype == np.float64:
d_type = "float64"
if paddle.device.get_device() == "cpu" and d_type == "float16":
# matmul fp16 is not supported on CPU
pass
else:
dx, dy, ddout = actual(
self.primal0, self.primal1, self.trans_0, self.trans_1, d_type
)
dx_, dy_, ddout_ = desired(
self.primal0, self.primal1, self.trans_0, self.trans_1, d_type
)
np.testing.assert_allclose(
actual=dx,
desired=dx_,
rtol=TOLERANCE[d_type]['rtol'],
atol=TOLERANCE[d_type]['atol'],
)
np.testing.assert_allclose(
actual=dy,
desired=dy_,
rtol=TOLERANCE[d_type]['rtol'],
atol=TOLERANCE[d_type]['atol'],
)
np.testing.assert_allclose(
actual=ddout,
desired=ddout_,
rtol=TOLERANCE[d_type]['rtol'],
atol=TOLERANCE[d_type]['atol'],
)
@param.parameterized_class(
('primal0', 'primal1', 'trans_0', 'trans_1', 'dtype'),
[
(
np.random.rand(2, 3, 4),
np.random.rand(4),
False,
False,
np.float16,
),
(
np.random.rand(2, 3, 4),
np.random.rand(4),
False,
False,
np.float32,
),
(
np.random.rand(2, 3, 4),
np.random.rand(4),
False,
False,
np.float64,
),
(
np.random.rand(2, 2, 3),
np.random.rand(2, 3, 2),
False,
False,
np.float16,
),
(
np.random.rand(2, 2, 3),
np.random.rand(2, 3, 2),
False,
False,
np.float32,
),
(
np.random.rand(2, 2, 3),
np.random.rand(2, 3, 2),
False,
False,
np.float64,
),
(
np.random.rand(2, 4, 3),
np.random.rand(2, 5, 4),
True,
True,
np.float64,
),
(
np.random.rand(2, 2, 3),
np.random.rand(1, 3, 2),
False,
False,
np.float64,
),
(
np.random.rand(2, 1, 5, 2),
np.random.rand(1, 3, 2, 4),
False,
False,
np.float32,
),
],
)
class TestMatmulTripleGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.dtype)
cls.primal1 = cls.primal1.astype(cls.dtype)
cls.trans_0 = cls.trans_0
cls.trans_1 = cls.trans_1
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_matmul_grad_comp(self):
def actual(primal0, primal1, trans_0, trans_1, dtype_):
core._set_prim_backward_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype=dtype_, stop_gradient=False)
y = paddle.to_tensor(primal1, dtype=dtype_, stop_gradient=False)
out = paddle.matmul(x, y, trans_0, trans_1)
dout = paddle.ones_like(out, dtype=dtype_)
dout.stop_gradient = False
ddx = paddle.ones_like(x, dtype=dtype_)
ddx.stop_gradient = False
ddy = paddle.ones_like(y, dtype=dtype_)
ddy.stop_gradient = False
res = paddle.grad(
[out], [x, y], dout, create_graph=True, retain_graph=True
)
res_double = paddle.grad(
res,
[x, y, dout],
[ddx, ddy],
create_graph=True,
retain_graph=True,
)
res_triple = paddle.grad(
res_double,
[x, y, dout, ddx, ddy],
create_graph=False,
retain_graph=False,
)
return (
res_double[0].numpy(),
res_double[1].numpy(),
res_double[2].numpy(),
res_triple[0].numpy(),
res_triple[1].numpy(),
)
def desired(primal0, primal1, trans_0, trans_1, dtype_):
core._set_prim_backward_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype=dtype_, stop_gradient=False)
y = paddle.to_tensor(primal1, dtype=dtype_, stop_gradient=False)
out = paddle.matmul(x, y, trans_0, trans_1)
dout = paddle.ones_like(out, dtype=dtype_)
dout.stop_gradient = False
ddx = paddle.ones_like(x, dtype=dtype_)
ddx.stop_gradient = False
ddy = paddle.ones_like(y, dtype=dtype_)
ddy.stop_gradient = False
res = paddle.grad(
[out], [x, y], [dout], create_graph=True, retain_graph=True
)
res_double = paddle.grad(
res,
[x, y, dout],
[ddx, ddy],
create_graph=True,
retain_graph=True,
)
res_triple = paddle.grad(
res_double,
[x, y, dout, ddx, ddy],
create_graph=False,
retain_graph=True,
)
return (
res_double[0].numpy(),
res_double[1].numpy(),
res_double[2].numpy(),
res_triple[0].numpy(),
res_triple[1].numpy(),
)
d_type = "float32"
if self.primal0.dtype == np.float16:
d_type = "float16"
elif self.primal0.dtype == np.float64:
d_type = "float64"
if paddle.device.get_device() == "cpu" and d_type == "float16":
# matmul fp16 is not supported on CPU
pass
else:
dx, dy, ddout, dx2, dy2 = actual(
self.primal0, self.primal1, self.trans_0, self.trans_1, d_type
)
dx_, dy_, ddout_, dx2_, dy2_ = desired(
self.primal0, self.primal1, self.trans_0, self.trans_1, d_type
)
np.testing.assert_allclose(
actual=dx,
desired=dx_,
rtol=TOLERANCE[d_type]['rtol'],
atol=TOLERANCE[d_type]['atol'],
)
np.testing.assert_allclose(
actual=dy,
desired=dy_,
rtol=TOLERANCE[d_type]['rtol'],
atol=TOLERANCE[d_type]['atol'],
)
np.testing.assert_allclose(
actual=ddout,
desired=ddout_,
rtol=TOLERANCE[d_type]['rtol'],
atol=TOLERANCE[d_type]['atol'],
)
np.testing.assert_allclose(
actual=dx2,
desired=dx2_,
rtol=TOLERANCE[d_type]['rtol'],
atol=TOLERANCE[d_type]['atol'],
)
np.testing.assert_allclose(
actual=dy2,
desired=dy2_,
rtol=TOLERANCE[d_type]['rtol'],
atol=TOLERANCE[d_type]['atol'],
)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
unittest.main()
......@@ -99,6 +99,18 @@ class TestSumGradComp(unittest.TestCase):
atol=0,
)
def test_sum_grad_comp_6(self):
self.primal = np.random.rand(3, 2, 5)
self.cotangent = np.random.rand(3, 1, 1)
paddle.disable_static()
np.testing.assert_allclose(
actual=actual(self.primal, self.cotangent, [-2, -1], True),
desired=desired(self.primal, self.cotangent, [-2, -1], True),
rtol=1e-6,
atol=0,
)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core
core._set_prim_backward_enabled(True)
# when dim = 1, the reshape op will be deleted by the backward algorithm,
# so it's better to use matmul_grad in the static composite pattern
# batched matrix * batched matrix 4 for trans out.shape = (2, 3, 5)
# batched matrix * broadcasted vector out.shape = (2, 3)
# batched matrix * broadcasted matrix out.shape = (2, 3, 5, 4)
TOLERANCE = {
"float16": {"rtol": 1e-3, "atol": 1e-3},
"float32": {"rtol": 1e-6, "atol": 1e-6},
"float64": {"rtol": 1e-15, "atol": 1e-15},
}
# TODO(ruting): add test cases when static backward is fixed
@param.parameterized_class(
('primal0', 'primal1', 'primal2', 'trans_0', 'trans_1', 'dtype'),
[
# (
# np.random.rand(2),
# np.random.rand(2),
# np.random.rand(1),
# False,
# False,
# ),
# (
# np.random.rand(2, 3),
# np.random.rand(3),
# np.random.rand(2),
# False,
# False,
# ),
# (
# np.random.rand(2),
# np.random.rand(2, 3),
# np.random.rand(3),
# False,
# False,
# ),
# (
# np.random.rand(2),
# np.random.rand(3, 2),
# np.random.rand(3),
# False,
# True,
# ),
# (
# np.random.rand(2, 3, 4),
# np.random.rand(4),
# np.random.rand(2, 3),
# False,
# False,
# ),
(
np.random.rand(2, 3, 4),
np.random.rand(2, 4, 5),
np.random.rand(2, 3, 5),
False,
False,
np.float16,
),
(
np.random.rand(2, 4, 3),
np.random.rand(2, 4, 5),
np.random.rand(2, 3, 5),
True,
False,
np.float16,
),
(
np.random.rand(2, 3, 4),
np.random.rand(2, 5, 4),
np.random.rand(2, 3, 5),
False,
True,
np.float16,
),
(
np.random.rand(2, 4, 3),
np.random.rand(2, 5, 4),
np.random.rand(2, 3, 5),
True,
True,
np.float16,
),
(
np.random.rand(2, 1, 5, 2),
np.random.rand(1, 3, 2, 4),
np.random.rand(2, 3, 5, 4),
False,
False,
np.float16,
),
(
np.random.rand(2, 3, 4),
np.random.rand(2, 4, 5),
np.random.rand(2, 3, 5),
False,
False,
np.float32,
),
(
np.random.rand(2, 4, 3),
np.random.rand(2, 4, 5),
np.random.rand(2, 3, 5),
True,
False,
np.float32,
),
(
np.random.rand(2, 3, 4),
np.random.rand(2, 5, 4),
np.random.rand(2, 3, 5),
False,
True,
np.float32,
),
(
np.random.rand(2, 4, 3),
np.random.rand(2, 5, 4),
np.random.rand(2, 3, 5),
True,
True,
np.float32,
),
(
np.random.rand(2, 1, 5, 2),
np.random.rand(1, 3, 2, 4),
np.random.rand(2, 3, 5, 4),
False,
False,
np.float32,
),
(
np.random.rand(2, 3, 4),
np.random.rand(2, 4, 5),
np.random.rand(2, 3, 5),
False,
False,
np.float64,
),
(
np.random.rand(2, 4, 3),
np.random.rand(2, 4, 5),
np.random.rand(2, 3, 5),
True,
False,
np.float64,
),
(
np.random.rand(2, 3, 4),
np.random.rand(2, 5, 4),
np.random.rand(2, 3, 5),
False,
True,
np.float64,
),
(
np.random.rand(2, 4, 3),
np.random.rand(2, 5, 4),
np.random.rand(2, 3, 5),
True,
True,
np.float64,
),
(
np.random.rand(2, 1, 5, 2),
np.random.rand(1, 3, 2, 4),
np.random.rand(2, 3, 5, 4),
False,
False,
np.float64,
),
],
)
class TestMatmulDoubleGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.dtype)
cls.primal1 = cls.primal1.astype(cls.dtype)
cls.primal2 = cls.primal2.astype(cls.dtype)
cls.trans_0 = cls.trans_0
cls.trans_1 = cls.trans_1
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_matmul_grad_comp(self):
def actual(primal0, primal1, primal2, trans_0, trans_1):
core._set_prim_backward_enabled(True)
paddle.enable_static()
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
y = paddle.static.data('primal1', primal1.shape, primal1.dtype)
z = paddle.static.data('primal2', primal2.shape, primal2.dtype)
x.stop_gradient = False
y.stop_gradient = False
z.stop_gradient = False
out = paddle.matmul(x, y, trans_0, trans_1)
res = paddle.static.gradients([out], [x, y], z)
res_double = paddle.static.gradients(res, [x, y, z])
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(
program=mp,
feed={
'primal0': primal0,
'primal1': primal1,
'primal2': primal2,
},
fetch_list=[
res_double[0].name,
res_double[1].name,
res_double[2].name,
],
)
return out[0], out[1], out[2]
def desired(primal0, primal1, primal2, trans_0, trans_1):
core._set_prim_backward_enabled(False)
paddle.enable_static()
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
y = paddle.static.data('primal1', primal1.shape, primal1.dtype)
z = paddle.static.data('primal2', primal2.shape, primal2.dtype)
x.stop_gradient = False
y.stop_gradient = False
z.stop_gradient = False
out = paddle.matmul(x, y, trans_0, trans_1)
res = paddle.static.gradients([out], [x, y], z)
res_double = paddle.static.gradients(res, [x, y, z])
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(
program=mp,
feed={
'primal0': primal0,
'primal1': primal1,
'primal2': primal2,
},
fetch_list=[
res_double[0].name,
res_double[1].name,
res_double[2].name,
],
)
return out[0], out[1], out[2]
dtype = 'float32'
if self.primal0.dtype == np.float16:
dtype = 'float16'
elif self.primal0.dtype == np.float64:
dtype = 'float64'
if paddle.device.get_device() == "cpu" and dtype == "float16":
# matmul fp16 is not supported on CPU
pass
else:
dx, dy, ddout = actual(
self.primal0,
self.primal1,
self.primal2,
self.trans_0,
self.trans_1,
)
dx_, dy_, ddout_ = desired(
self.primal0,
self.primal1,
self.primal2,
self.trans_0,
self.trans_1,
)
np.testing.assert_allclose(
actual=dx,
desired=dx_,
rtol=TOLERANCE[dtype]['rtol'],
atol=TOLERANCE[dtype]['atol'],
)
np.testing.assert_allclose(
actual=dy,
desired=dy_,
rtol=TOLERANCE[dtype]['rtol'],
atol=TOLERANCE[dtype]['atol'],
)
np.testing.assert_allclose(
actual=ddout,
desired=ddout_,
rtol=TOLERANCE[dtype]['rtol'],
atol=TOLERANCE[dtype]['atol'],
)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
unittest.main()
......@@ -119,6 +119,18 @@ class TestSumGradComp(unittest.TestCase):
atol=0,
)
def test_sum_grad_comp_6(self):
self.primal = np.random.rand(3, 2, 5)
self.cotangent = np.random.rand(3, 1, 1)
paddle.enable_static()
np.testing.assert_allclose(
actual=actual(self.primal, self.cotangent, [-2, -1], True),
desired=desired(self.primal, self.cotangent, [-2, -1], True),
rtol=1e-6,
atol=0,
)
if __name__ == '__main__':
unittest.main()
......@@ -42,8 +42,7 @@ from paddle.fluid import core, framework
set(),
tuple(),
(
'fill_constant',
'elementwise_pow',
'elementwise_mul',
'fill_constant',
'elementwise_sub',
'elementwise_mul',
......
......@@ -28,8 +28,7 @@ class TestGetGradOpDescPrimEnabled(unittest.TestCase):
self.grad_sub_block = tuple()
self.desired_ops = 'tanh_grad'
self.desired_ops_no_skip = (
'fill_constant',
'elementwise_pow',
'elementwise_mul',
'fill_constant',
'elementwise_sub',
'elementwise_mul',
......
......@@ -181,6 +181,7 @@ disable_win_inference_test="^trt_quant_int8_yolov3_r50_test$|\
^test_tensordot$|\
^disable_win_inference_test$|\
^test_imperative_double_grad$|\
^test_comp_eager_matmul_double_grad$|\
^test_imperative_triple_grad$"
......