Commit dd5f33e1 authored by wangruting

original code

Parent 3ebc0f73
@@ -25,6 +25,10 @@ white_ops_list = [
"divide",
"sum",
"exp",
"matmul",
"dot",
"transpose",
"add",
]
inplace_out_type_map = {
...
@@ -38,6 +38,24 @@
namespace paddle {
namespace prim {
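// Composite add for the static graph: appends an "elementwise_add" OpDesc to
// the current block, wires the operands' variable names to X/Y, and returns a
// DescTensor bound to the op's Out variable.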
template <>
Tensor add<DescTensor>(const Tensor& x, const Tensor& y) {
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("elementwise_add");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetInput("Y",
{std::static_pointer_cast<prim::DescTensor>(y.impl())->Name()});
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
return out;
}
template <>
Tensor pow<DescTensor>(const Tensor& x, const Scalar& y) {
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
@@ -77,6 +95,29 @@ Tensor scale<DescTensor>(const Tensor& x,
return out;
}
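// Composite matmul for the static graph: records a matmul OpDesc with the
// transpose_X/transpose_Y attributes and returns the DescTensor of its Out.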
template <>
Tensor matmul<DescTensor>(const Tensor& x,
const Tensor& y,
bool transpose_x,
bool transpose_y) {
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("MatMul");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetInput("Y",
{std::static_pointer_cast<prim::DescTensor>(y.impl())->Name()});
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->SetAttr("transpose_X", transpose_x);
op->SetAttr("transpose_Y", transpose_y);
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
return out;
}
template <>
Tensor multiply<DescTensor>(const Tensor& x, const Tensor& y) {
// Grad infershape
@@ -236,6 +277,41 @@ Tensor reshape<DescTensor>(const Tensor& x, const IntArray& shape) {
return out;
}
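// Composite transpose for the static graph: records a transpose OpDesc whose
// "axis" attribute carries the permutation `perm`.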
template <>
Tensor transpose<DescTensor>(const Tensor& x, const std::vector<int>& perm) {
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("transpose");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->SetAttr("axis", perm);
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
return out;
}
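// Composite dot for the static graph: records a dot OpDesc over X and Y and
// returns the DescTensor of its Out.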
template <>
Tensor dot<DescTensor>(const Tensor& x, const Tensor& y) {
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("dot");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetInput("Y",
{std::static_pointer_cast<prim::DescTensor>(y.impl())->Name()});
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
return out;
}
template <>
Tensor exp<DescTensor>(const Tensor& x) {
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
...
@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h"
@@ -170,7 +171,7 @@ void divide_grad(const Tensor& x,
Tensor* dx,
Tensor* dy) {
if (dy) {
// dy = -(x/y^2) * grad_out
auto tmp0 = pow<T>(y, 2.0);
auto tmp1 = divide<T>(x, tmp0);
auto tmp2 = scale<T>(tmp1, -1.0, 0.0, true);
@@ -191,7 +192,7 @@ void divide_grad(const Tensor& x,
}
} // indicate we will compute dy
if (dx) {
// dx = (1/y) * grad_out
auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
auto tmp0 = divide<T>(one_tensor, y);
auto dx_res = multiply<T>(tmp0, out_grad);
@@ -303,5 +304,310 @@ void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
}
}
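// Composite double grad of matmul: given grad_out and the incoming gradients
// grad_x_grad / grad_y_grad, computes x_grad, y_grad and grad_out_grad.
// When the batch dims of x and y match (no broadcast), the operands are first
// reshaped into matrix form; otherwise the broadcast path computes per-batch
// products and reduce-sums them back to the operand shapes.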
template <typename T>
void matmul_double_grad(const Tensor& x,
const Tensor& y,
const Tensor& grad_out,
const paddle::optional<Tensor>& grad_x_grad,
const paddle::optional<Tensor>& grad_y_grad,
bool transpose_x,
bool transpose_y,
Tensor* x_grad,
Tensor* y_grad,
Tensor* grad_out_grad) {
// Get dims from the input x, y, output_grad
std::vector<std::int64_t> x_dims = vectorize(x.dims());
std::vector<std::int64_t> y_dims = vectorize(y.dims());
std::vector<std::int64_t> grad_out_dims = vectorize(grad_out.dims());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int ndim = grad_out_dims.size();
// Case1 : x's or y's dim = 1
bool is_broadcast = true;
if (x_ndim <= 2 || y_ndim <= 2) {
is_broadcast = false;
} else if (x_ndim != y_ndim) {
is_broadcast = true;
} else {
is_broadcast = !std::equal(
x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin());
}
if (!is_broadcast) {
// Case2: no broadcast or no batch size
Tensor x_help = x;
Tensor y_help = y;
Tensor grad_out_help = grad_out;
reshape_xyout_to_matrixsequence<T>(
&x_help, &y_help, &grad_out_help, transpose_x, transpose_y);
phi::DDim x_grad_dims;
if (x_grad) {
x_grad_dims = x_grad->dims();
if (x_grad_dims != x_help.dims()) {
*x_grad = reshape<T>(*x_grad, IntArray(phi::vectorize(x_help.dims())));
}
}
phi::DDim y_grad_dims;
if (y_grad) {
y_grad_dims = y_grad->dims();
if (y_grad_dims != y_help.dims()) {
*y_grad = reshape<T>(*y_grad, IntArray(phi::vectorize(y_help.dims())));
}
}
phi::DDim dgrad_out_dims;
if (grad_out_grad) {
dgrad_out_dims = grad_out_grad->dims();
if (dgrad_out_dims != grad_out_help.dims()) {
*grad_out_grad = reshape<T>(
*grad_out_grad, IntArray(phi::vectorize(grad_out_help.dims())));
}
}
bool dgrad_out_flag = false;
if (grad_x_grad) {
auto grad_x_grad_mat = grad_x_grad.get();
if (grad_x_grad_mat.dims() != x_help.dims()) {
grad_x_grad_mat = reshape<T>(grad_x_grad_mat,
IntArray(phi::vectorize(x_help.dims())));
}
if (y_grad) {
Tensor y_grad_tmp;
if (transpose_x && transpose_y) {
// y_grad = grad_out' * grad_x_grad'
auto tmp =
modify_dim_for_matmul<T>(grad_out, true, grad_x_grad_mat, false);
y_grad_tmp =
matmul<T>(std::get<0>(tmp), std::get<1>(tmp), true, true);
} else if (transpose_x) {
// y_grad = grad_x_grad * grad_out
auto tmp =
modify_dim_for_matmul<T>(grad_x_grad_mat, false, grad_out, true);
y_grad_tmp =
matmul<T>(std::get<0>(tmp), std::get<1>(tmp), false, false);
} else if (transpose_y) {
// y_grad = grad_out' * grad_x_grad
auto tmp =
modify_dim_for_matmul<T>(grad_out, true, grad_x_grad_mat, true);
y_grad_tmp =
matmul<T>(std::get<0>(tmp), std::get<1>(tmp), true, false);
} else {
// y_grad = grad_x_grad' * grad_out
auto tmp =
modify_dim_for_matmul<T>(grad_x_grad_mat, true, grad_out, true);
y_grad_tmp =
matmul<T>(std::get<0>(tmp), std::get<1>(tmp), true, false);
}
set_output<T>(y_grad_tmp, y_grad);
}
if (grad_out_grad) {
auto tmp = modify_dim_for_matmul<T>(grad_x_grad_mat, true, y, false);
auto grad_out_grad_tmp = matmul<T>(
std::get<0>(tmp), std::get<1>(tmp), transpose_x, transpose_y);
set_output<T>(grad_out_grad_tmp, grad_out_grad);
}
} else if (!grad_x_grad && y_grad) {
auto y_grad_tmp = full<T>(phi::vectorize(y.dims()), Scalar(0.0));
set_output<T>(y_grad_tmp, y_grad);
}
if (grad_y_grad) {
auto grad_y_grad_mat = grad_y_grad.get();
if (grad_y_grad_mat.dims() != y_help.dims()) {
grad_y_grad_mat = reshape<T>(grad_y_grad_mat,
IntArray(phi::vectorize(y_help.dims())));
}
if (x_grad) {
Tensor x_grad_tmp;
if (transpose_x && transpose_y) {
// x_grad = grad_y_grad' * grad_out'
auto tmp =
modify_dim_for_matmul<T>(grad_y_grad_mat, true, grad_out, false);
x_grad_tmp =
matmul<T>(std::get<0>(tmp), std::get<1>(tmp), true, true);
} else if (transpose_x) {
// x_grad = grad_y_grad * grad_out'
auto tmp =
modify_dim_for_matmul<T>(grad_y_grad_mat, false, grad_out, false);
x_grad_tmp =
matmul<T>(std::get<0>(tmp), std::get<1>(tmp), false, true);
} else if (transpose_y) {
// x_grad = grad_out * grad_y_grad
auto tmp =
modify_dim_for_matmul<T>(grad_out, false, grad_y_grad_mat, true);
x_grad_tmp =
matmul<T>(std::get<0>(tmp), std::get<1>(tmp), false, false);
} else {
// x_grad = grad_out * grad_y_grad'
auto tmp =
modify_dim_for_matmul<T>(grad_out, false, grad_y_grad_mat, false);
x_grad_tmp =
matmul<T>(std::get<0>(tmp), std::get<1>(tmp), false, true);
}
set_output<T>(x_grad_tmp, x_grad);
}
if (grad_out_grad) {
auto tmp = modify_dim_for_matmul<T>(x, true, grad_y_grad_mat, false);
auto grad_out_grad_tmp = matmul<T>(
std::get<0>(tmp), std::get<1>(tmp), transpose_x, transpose_y);
auto output_tmp = add<T>(grad_out_grad_tmp, *grad_out_grad);
set_output<T>(output_tmp, grad_out_grad);
}
} else if (!grad_y_grad && x_grad) {
auto x_grad_tmp = full<T>(phi::vectorize(x.dims()), Scalar(0.0));
set_output<T>(x_grad_tmp, x_grad);
}
if (grad_out_grad && !grad_x_grad && !grad_y_grad) {
auto grad_out_grad_tmp =
full<T>(phi::vectorize(grad_out.dims()), Scalar(0.0));
set_output<T>(grad_out_grad_tmp, grad_out_grad);
}
if (x_grad) {
if (x_grad_dims != x_help.dims()) {
*x_grad = reshape<T>(*x_grad, IntArray(phi::vectorize(x_grad_dims)));
}
}
if (y_grad) {
if (y_grad_dims != y_help.dims()) {
*y_grad = reshape<T>(*y_grad, IntArray(phi::vectorize(y_grad_dims)));
}
}
if (grad_out_grad) {
if (dgrad_out_dims != grad_out_help.dims()) {
*grad_out_grad = reshape<T>(*grad_out_grad,
IntArray(phi::vectorize(dgrad_out_dims)));
}
}
} else {
// Case3: broadcast. It need cost much time to reduce sum for the
// broadcast and wastes the memory.
// So we should avoid the case in reality.
VLOG(3) << "It need cost much time to reduce sum for the broadcast and "
"wastes the memory. So we should avoid the case in reality";
Tensor x_grad_help;
Tensor y_grad_help;
Tensor grad_out_grad_help;
if (transpose_x) {
if (transpose_y) {
if (x_grad && grad_y_grad) {
x_grad_help = matmul<T>(grad_y_grad.get(), grad_out, true, true);
}
if (y_grad && grad_x_grad) {
y_grad_help = matmul<T>(grad_out, grad_x_grad.get(), true, true);
}
} else {
if (x_grad && grad_y_grad) {
x_grad_help = matmul<T>(grad_y_grad.get(), grad_out, false, true);
}
if (y_grad && grad_x_grad) {
y_grad_help = matmul<T>(grad_x_grad.get(), grad_out, false, false);
}
}
} else {
if (transpose_y) {
if (x_grad && grad_y_grad) {
x_grad_help = matmul<T>(grad_out, grad_y_grad.get(), false, false);
}
if (y_grad && grad_x_grad) {
y_grad_help = matmul<T>(grad_out, grad_x_grad.get(), true, false);
}
} else {
if (x_grad && grad_y_grad) {
x_grad_help = matmul<T>(grad_out, grad_y_grad.get(), false, true);
}
if (y_grad && grad_x_grad) {
y_grad_help = matmul<T>(grad_x_grad.get(), grad_out, true, false);
}
}
}
// get help dims
const std::vector<std::int64_t> x_grad_help_dims =
vectorize(x_grad_help.dims());
const std::vector<std::int64_t> y_grad_help_dims =
vectorize(y_grad_help.dims());
std::vector<std::int64_t> x_grad_broadcast_dims(ndim);
std::vector<std::int64_t> y_grad_broadcast_dims(ndim);
std::fill(x_grad_broadcast_dims.data(),
x_grad_broadcast_dims.data() + ndim - x_ndim,
1);
std::fill(y_grad_broadcast_dims.data(),
y_grad_broadcast_dims.data() + ndim - y_ndim,
1);
std::copy(x_dims.data(),
x_dims.data() + x_ndim,
x_grad_broadcast_dims.data() + ndim - x_ndim);
std::copy(y_dims.data(),
y_dims.data() + y_ndim,
y_grad_broadcast_dims.data() + ndim - y_ndim);
std::vector<int> x_grad_reduce_dims;
std::vector<int> y_grad_reduce_dims;
for (int ix_grad = 0; ix_grad <= ndim - 3; ix_grad++) {
if (x_grad_help_dims[ix_grad] != 1 &&
x_grad_broadcast_dims[ix_grad] == 1) {
x_grad_reduce_dims.push_back(ix_grad);
}
if (y_grad_help_dims[ix_grad] != 1 &&
y_grad_broadcast_dims[ix_grad] == 1) {
y_grad_reduce_dims.push_back(ix_grad);
}
}
// Reduce sum to get grad by ReduceSum
if (x_grad && x_grad_help.initialized()) {
if (!x_grad_reduce_dims.empty()) {
x_grad_help = sum<T>(x_grad_help, IntArray(x_grad_reduce_dims));
}
x_grad_help = reshape<T>(x_grad_help, IntArray(phi::vectorize(x.dims())));
} else if (x_grad && !x_grad_help.initialized()) {
x_grad_help = full<T>(phi::vectorize(x.dims()), Scalar(0.0));
}
set_output<T>(x_grad_help, x_grad);
if (y_grad && y_grad_help.initialized()) {
if (!y_grad_reduce_dims.empty()) {
y_grad_help = sum<T>(y_grad_help, IntArray(y_grad_reduce_dims));
}
y_grad_help = reshape<T>(y_grad_help, IntArray(phi::vectorize(y.dims())));
} else if (y_grad && !y_grad_help.initialized()) {
y_grad_help = full<T>(phi::vectorize(y.dims()), Scalar(0.0));
}
set_output<T>(y_grad_help, y_grad);
if (grad_out_grad) {
// Calculate the gradient of OutputGrad(Out)
if (grad_x_grad) {
grad_out_grad_help =
matmul<T>(grad_x_grad.get(), y, transpose_x, transpose_y);
}
if (grad_y_grad) {
auto grad_out_grad_help_2 =
matmul<T>(x, grad_y_grad.get(), transpose_x, transpose_y);
grad_out_grad_help = add<T>(grad_out_grad_help, grad_out_grad_help_2);
}
set_output<T>(grad_out_grad_help, grad_out_grad);
}
}
}
} // namespace prim
} // namespace paddle
@@ -17,30 +17,34 @@
#include <vector>
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle {
namespace prim {
// We put some api like utils here
using Tensor = paddle::experimental::Tensor;
template <typename T>
Tensor empty(const paddle::experimental::IntArray& shape,
paddle::experimental::DataType dype,
const paddle::Place& place);
template <typename T>
Tensor empty_like(const Tensor& x,
paddle::experimental::DataType dtype,
const paddle::Place& place);
// copy tensor for output ptr; in static graph we need to use the assign op
template <typename T>
void by_pass(const Tensor& x, Tensor* out);
// set output ptr impl with tmp ptr impl; in dygraph, OutGradMeta should be set
template <typename T>
void set_output(const Tensor& x_tmp, Tensor* x);
// These methods don't need to be specified
static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims,
@@ -78,5 +82,90 @@ static phi::DDim get_reduce_dims(const phi::DDim& x_dims,
return get_reduce_dims_from_out(out_dims, x_dims);
}
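// Adjusts 3-D matmul operands to 2-D when the output `out` is a plain matrix:
// either the leading batch dim is folded into the row dim, or the operand is
// permuted with {1, 0, 2} before its remaining dims are merged.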
template <typename T>
std::tuple<Tensor, Tensor> modify_dim_for_matmul(const Tensor& a,
bool is_fold_init_dims_a,
const Tensor& b,
const Tensor* out,
bool is_fold_init_dims_b) {
Tensor a_out = a;
Tensor b_out = b;
bool need_combine =
(a.dims().size() == 3 || b.dims().size() == 3) && out->dims().size() == 2;
if (need_combine) {
auto a_dims = a.dims();
auto b_dims = b.dims();
if (is_fold_init_dims_a) {
if (a_dims.size() == 3) {
std::vector<int64_t> a_shape = {a_dims[0] * a_dims[1], a_dims[2]};
a_out = reshape<T>(a_out, IntArray(a_shape));
}
} else {
if (a_dims.size() == 3) {
a_out = transpose<T>(a, std::vector<int>({1, 0, 2}));
std::vector<int64_t> a_shape = {a_dims[0], a_dims[1] * a_dims[2]};
a_out = reshape<T>(a_out, IntArray(a_shape));
}
}
if (is_fold_init_dims_b) {
if (b_dims.size() == 3) {
std::vector<int64_t> b_shape = {b_dims[0] * b_dims[1], b_dims[2]};
b_out = reshape<T>(b_out, IntArray(b_shape));
}
} else {
if (b_dims.size() == 3) {
b_out = transpose<T>(b, std::vector<int>({1, 0, 2}));
std::vector<int64_t> b_shape = {b_dims[0], b_dims[1] * b_dims[2]};
b_out = reshape<T>(b_out, IntArray(b_shape));
}
}
}
std::tuple<Tensor, Tensor> output(a_out, b_out);
return output;
}
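// Reshapes a tensor to the {batch, h, w} (or {h, w}) layout described by its
// MatDescriptor, swapping height and width when the descriptor is transposed.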
template <typename T>
void reshape_tensor_to_matrixsequence(
Tensor* x, const phi::funcs::MatDescriptor& descriptor) {
int64_t h, w;
h = descriptor.height_;
w = descriptor.width_;
if (descriptor.trans_) {
std::swap(w, h);
}
if (descriptor.batch_size_) {
*x = reshape<T>(*x, std::vector<int64_t>({descriptor.batch_size_, h, w}));
} else {
*x = reshape<T>(*x, std::vector<int64_t>({h, w}));
}
}
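// Normalizes x, y and out for a matmul-style computation: 1-D operands are
// promoted to matrices, then all three tensors are reshaped to the
// batch/height/width layout given by their matrix descriptors.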
template <typename T>
void reshape_xyout_to_matrixsequence(
Tensor* x, Tensor* y, Tensor* out, bool trans_x, bool trans_y) {
if (x->dims().size() == 1) {
*x = reshape<T>(*x, std::vector<int64_t>({1, x->dims()[0]}));
}
if (y->dims().size() == 1) {
*y = reshape<T>(*y, std::vector<int64_t>({y->dims()[0], 1}));
}
auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x->dims(), 0, trans_x);
auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y->dims(), 0, trans_y);
if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
*out = reshape<T>(
*out, std::vector<int64_t>({mat_dim_x.height_, mat_dim_y.width_}));
} else {
*out = reshape<T>(*out,
std::vector<int64_t>({(std::max)(mat_dim_x.batch_size_,
mat_dim_y.batch_size_),
mat_dim_x.height_,
mat_dim_y.width_}));
}
reshape_tensor_to_matrixsequence<T>(x, mat_dim_x);
reshape_tensor_to_matrixsequence<T>(y, mat_dim_y);
}
} // namespace prim
} // namespace paddle