未验证 提交 5a55f13b 编写于 作者: W wangzhen38 提交者: GitHub

【code format】Fix cpplint style 4 (#43695)

* cpplint fix 2

* cpplint fix 2

* fix cpplint style 4

* fix cpplint style 4

* fix cpplint style 4

* fix cpplint style 4
上级 75080988
...@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef POLY_UTIL_CC_ #ifndef POLY_UTIL_CC_
#define POLY_UTIL_CC_ #define POLY_UTIL_CC_
...@@ -26,61 +25,62 @@ using gpc::gpc_free_polygon; ...@@ -26,61 +25,62 @@ using gpc::gpc_free_polygon;
using gpc::gpc_polygon_clip; using gpc::gpc_polygon_clip;
template <class T> template <class T>
void Array2PointVec(const T*& box, const size_t box_size, void Array2PointVec(const T* box,
std::vector<Point_<T>>& vec) { const size_t box_size,
std::vector<Point_<T>>* vec) {
size_t pts_num = box_size / 2; size_t pts_num = box_size / 2;
vec.resize(pts_num); (*vec).resize(pts_num);
for (size_t i = 0; i < pts_num; i++) { for (size_t i = 0; i < pts_num; i++) {
vec.at(i).x = box[2 * i]; (*vec).at(i).x = box[2 * i];
vec.at(i).y = box[2 * i + 1]; (*vec).at(i).y = box[2 * i + 1];
} }
} }
template <class T> template <class T>
void Array2Poly(const T*& box, const size_t box_size, gpc::gpc_polygon& poly) { void Array2Poly(const T* box, const size_t box_size, gpc::gpc_polygon* poly) {
size_t pts_num = box_size / 2; size_t pts_num = box_size / 2;
poly.num_contours = 1; (*poly).num_contours = 1;
poly.hole = (int*)malloc(sizeof(int)); (*poly).hole = reinterpret_cast<int*>(malloc(sizeof(int)));
poly.hole[0] = 0; (*poly).hole[0] = 0;
poly.contour = (gpc::gpc_vertex_list*)malloc(sizeof(gpc::gpc_vertex_list)); (*poly).contour = (gpc::gpc_vertex_list*)malloc(sizeof(gpc::gpc_vertex_list));
poly.contour->num_vertices = pts_num; (*poly).contour->num_vertices = pts_num;
poly.contour->vertex = (*poly).contour->vertex =
(gpc::gpc_vertex*)malloc(sizeof(gpc::gpc_vertex) * pts_num); (gpc::gpc_vertex*)malloc(sizeof(gpc::gpc_vertex) * pts_num);
for (size_t i = 0; i < pts_num; ++i) { for (size_t i = 0; i < pts_num; ++i) {
poly.contour->vertex[i].x = box[2 * i]; (*poly).contour->vertex[i].x = box[2 * i];
poly.contour->vertex[i].y = box[2 * i + 1]; (*poly).contour->vertex[i].y = box[2 * i + 1];
} }
} }
template <class T> template <class T>
void PointVec2Poly(const std::vector<Point_<T>>& vec, gpc::gpc_polygon& poly) { void PointVec2Poly(const std::vector<Point_<T>>& vec, gpc::gpc_polygon* poly) {
int pts_num = vec.size(); int pts_num = vec.size();
poly.num_contours = 1; (*poly).num_contours = 1;
poly.hole = (int*)malloc(sizeof(int)); (*poly).hole = reinterpret_cast<int*>(malloc(sizeof(int)));
poly.hole[0] = 0; (*poly).hole[0] = 0;
poly.contour = (gpc::gpc_vertex_list*)malloc(sizeof(gpc::gpc_vertex_list)); (*poly).contour = (gpc::gpc_vertex_list*)malloc(sizeof(gpc::gpc_vertex_list));
poly.contour->num_vertices = pts_num; (*poly).contour->num_vertices = pts_num;
poly.contour->vertex = (*poly).contour->vertex =
(gpc::gpc_vertex*)malloc(sizeof(gpc::gpc_vertex) * pts_num); (gpc::gpc_vertex*)malloc(sizeof(gpc::gpc_vertex) * pts_num);
for (size_t i = 0; i < pts_num; ++i) { for (size_t i = 0; i < pts_num; ++i) {
poly.contour->vertex[i].x = vec[i].x; (*poly).contour->vertex[i].x = vec[i].x;
poly.contour->vertex[i].y = vec[i].y; (*poly).contour->vertex[i].y = vec[i].y;
} }
} }
template <class T> template <class T>
void Poly2PointVec(const gpc::gpc_vertex_list& contour, void Poly2PointVec(const gpc::gpc_vertex_list& contour,
std::vector<Point_<T>>& vec) { std::vector<Point_<T>>* vec) {
int pts_num = contour.num_vertices; int pts_num = contour.num_vertices;
vec.resize(pts_num); (*vec).resize(pts_num);
for (int i = 0; i < pts_num; i++) { for (int i = 0; i < pts_num; i++) {
vec.at(i).x = contour.vertex[i].x; (*vec).at(i).x = contour.vertex[i].x;
vec.at(i).y = contour.vertex[i].y; (*vec).at(i).y = contour.vertex[i].y;
} }
} }
template <class T> template <class T>
T GetContourArea(std::vector<Point_<T>>& vec) { T GetContourArea(const std::vector<Point_<T>>& vec) {
size_t pts_num = vec.size(); size_t pts_num = vec.size();
if (pts_num < 3) return T(0.); if (pts_num < 3) return T(0.);
T area = T(0.); T area = T(0.);
...@@ -96,17 +96,19 @@ T PolyArea(const T* box, const size_t box_size, const bool normalized) { ...@@ -96,17 +96,19 @@ T PolyArea(const T* box, const size_t box_size, const bool normalized) {
// If coordinate values are is invalid // If coordinate values are is invalid
// if area size <= 0, return 0. // if area size <= 0, return 0.
std::vector<Point_<T>> vec; std::vector<Point_<T>> vec;
Array2PointVec<T>(box, box_size, vec); Array2PointVec<T>(box, box_size, &vec);
return GetContourArea<T>(vec); return GetContourArea<T>(vec);
} }
template <class T> template <class T>
T PolyOverlapArea(const T* box1, const T* box2, const size_t box_size, T PolyOverlapArea(const T* box1,
const T* box2,
const size_t box_size,
const bool normalized) { const bool normalized) {
gpc::gpc_polygon poly1; gpc::gpc_polygon poly1;
gpc::gpc_polygon poly2; gpc::gpc_polygon poly2;
Array2Poly<T>(box1, box_size, poly1); Array2Poly<T>(box1, box_size, &poly1);
Array2Poly<T>(box2, box_size, poly2); Array2Poly<T>(box2, box_size, &poly2);
gpc::gpc_polygon respoly; gpc::gpc_polygon respoly;
gpc::gpc_op op = gpc::GPC_INT; gpc::gpc_op op = gpc::GPC_INT;
gpc::gpc_polygon_clip(op, &poly2, &poly1, &respoly); gpc::gpc_polygon_clip(op, &poly2, &poly1, &respoly);
...@@ -115,7 +117,7 @@ T PolyOverlapArea(const T* box1, const T* box2, const size_t box_size, ...@@ -115,7 +117,7 @@ T PolyOverlapArea(const T* box1, const T* box2, const size_t box_size,
int contour_num = respoly.num_contours; int contour_num = respoly.num_contours;
for (int i = 0; i < contour_num; ++i) { for (int i = 0; i < contour_num; ++i) {
std::vector<Point_<T>> resvec; std::vector<Point_<T>> resvec;
Poly2PointVec<T>(respoly.contour[i], resvec); Poly2PointVec<T>(respoly.contour[i], &resvec);
// inter_area += std::fabs(cv::contourArea(resvec)) + 0.5f * // inter_area += std::fabs(cv::contourArea(resvec)) + 0.5f *
// (cv::arcLength(resvec, true)); // (cv::arcLength(resvec, true));
inter_area += GetContourArea<T>(resvec); inter_area += GetContourArea<T>(resvec);
......
...@@ -11,9 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,9 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once
#ifndef POLY_UTIL_H_
#define POLY_UTIL_H_
#include <vector> #include <vector>
...@@ -44,31 +42,32 @@ class Point_ { ...@@ -44,31 +42,32 @@ class Point_ {
}; };
template <class T> template <class T>
void Array2PointVec(const T*& box, const size_t box_size, void Array2PointVec(const T* box,
std::vector<Point_<T>>& vec); const size_t box_size,
std::vector<Point_<T>>* vec);
template <class T> template <class T>
void Array2Poly(const T*& box, const size_t box_size, gpc::gpc_polygon& poly); void Array2Poly(const T* box, const size_t box_size, gpc::gpc_polygon* poly);
template <class T> template <class T>
void PointVec2Poly(const std::vector<Point_<T>>& vec, gpc::gpc_polygon& poly); void PointVec2Poly(const std::vector<Point_<T>>& vec, gpc::gpc_polygon* poly);
template <class T> template <class T>
void Poly2PointVec(const gpc::gpc_vertex_list& contour, void Poly2PointVec(const gpc::gpc_vertex_list& contour,
std::vector<Point_<T>>& vec); std::vector<Point_<T>>* vec);
template <class T> template <class T>
T GetContourArea(std::vector<Point_<T>>& vec); T GetContourArea(const std::vector<Point_<T>>& vec);
template <class T> template <class T>
T PolyArea(const T* box, const size_t box_size, const bool normalized); T PolyArea(const T* box, const size_t box_size, const bool normalized);
template <class T> template <class T>
T PolyOverlapArea(const T* box1, const T* box2, const size_t box_size, T PolyOverlapArea(const T* box1,
const T* box2,
const size_t box_size,
const bool normalized); const bool normalized);
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#include "paddle/fluid/operators/detection/poly_util.cc" #include "paddle/fluid/operators/detection/poly_util.cc"
#endif // POLY_UTIL_H_
...@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h" #include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h"
namespace { namespace {
...@@ -49,19 +48,23 @@ static std::vector<int64_t> Transpose(const std::vector<int64_t>& x, ...@@ -49,19 +48,23 @@ static std::vector<int64_t> Transpose(const std::vector<int64_t>& x,
size_t axis_size = axis.size(); size_t axis_size = axis.size();
auto axis_set = std::set<int>(axis.begin(), axis.end()); auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(), axis_size, PADDLE_ENFORCE_EQ(axis_set.size(),
axis_size,
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"In an axis array, elements must be unique.")); "In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(in_rank, axis_size, PADDLE_ENFORCE_EQ(in_rank,
axis_size,
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"The input dimension's size " "The input dimension's size "
"should be equal to the axis's size. " "should be equal to the axis's size. "
"But received dimension is %d, " "But received dimension is %d, "
"axis's size is %d", "axis's size is %d",
in_rank, axis_size)); in_rank,
axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size, PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
axis_size,
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1).")); "Axis values must be ranging from 0 to (dims - 1)."));
...@@ -85,7 +88,8 @@ std::vector<int64_t> GetInputStrides(const ExecutionContext& ctx, ...@@ -85,7 +88,8 @@ std::vector<int64_t> GetInputStrides(const ExecutionContext& ctx,
auto& MatrixDimsFromVector = auto& MatrixDimsFromVector =
input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector; input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector;
phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor( phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor(
MatrixDimsFromVector(new_dims), 0, MatrixDimsFromVector(new_dims),
0,
ctx.Attr<bool>(std::string("trans_") + ctx.Attr<bool>(std::string("trans_") +
static_cast<char>(std::tolower(input_name[0])))); static_cast<char>(std::tolower(input_name[0]))));
...@@ -125,16 +129,27 @@ template <typename T> ...@@ -125,16 +129,27 @@ template <typename T>
void ExecuteMatMulV2(const ExecutionContext& ctx, void ExecuteMatMulV2(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx, const MKLDNNDeviceContext& dev_ctx,
const dnnl::engine onednn_engine, const dnnl::engine onednn_engine,
paddle::platform::Place cpu_place, const Tensor* x, paddle::platform::Place cpu_place,
std::vector<int64_t>& x_dims, bool trans_x, const Tensor* x,
const Tensor* y, std::vector<int64_t>& y_dims, const std::vector<int64_t>& x_dims,
bool trans_y, Tensor* out, std::vector<int64_t>& out_dims, bool trans_x,
const Tensor* y,
const std::vector<int64_t>& y_dims,
bool trans_y,
Tensor* out,
const std::vector<int64_t>& out_dims,
int execution_number = 0) { int execution_number = 0) {
std::vector<int64_t> x_strides_override = GetInputStrides(ctx, "X"); std::vector<int64_t> x_strides_override = GetInputStrides(ctx, "X");
std::vector<int64_t> y_strides_override = GetInputStrides(ctx, "Y"); std::vector<int64_t> y_strides_override = GetInputStrides(ctx, "Y");
MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims, MatMulV2MKLDNNHandler<T> handler(onednn_engine,
trans_x, y_dims, trans_y, IsOutputFused(ctx), ctx.GetPlace(),
x_strides_override, y_strides_override); x_dims,
trans_x,
y_dims,
trans_y,
IsOutputFused(ctx),
x_strides_override,
y_strides_override);
const auto src_memory_p = handler.AcquireSrcMemory(x); const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y); const auto weights_memory_p = handler.AcquireWeightsMemory(y);
...@@ -177,44 +192,48 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -177,44 +192,48 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
void CalculateMatrixDims(const ExecutionContext& ctx, void CalculateMatrixDims(const ExecutionContext& ctx,
const std::vector<int64_t>& x_dims, const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& y_dims, const std::vector<int64_t>& y_dims,
std::vector<int64_t>& x_bd_dims, std::vector<int64_t>* x_bd_dims,
std::vector<int64_t>& y_bd_dims, std::vector<int64_t>* y_bd_dims,
std::vector<int64_t>& out_dims, Tensor* out) const { std::vector<int64_t>* out_dims,
Tensor* out) const {
if (x_dims.size() == 1) { if (x_dims.size() == 1) {
x_bd_dims[x_bd_dims.size() - 1] = x_dims[0]; (*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0];
} else if (x_dims.size() == 2) { } else if (x_dims.size() == 2) {
x_bd_dims[x_bd_dims.size() - 1] = x_dims[1]; (*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[1];
x_bd_dims[x_bd_dims.size() - 2] = x_dims[0]; (*x_bd_dims)[(*x_bd_dims).size() - 2] = x_dims[0];
} else { } else {
for (size_t i = 0; i < x_dims.size(); ++i) { for (size_t i = 0; i < x_dims.size(); ++i) {
x_bd_dims[x_bd_dims.size() - x_dims.size() + i] = x_dims[i]; (*x_bd_dims)[(*x_bd_dims).size() - x_dims.size() + i] = x_dims[i];
} }
} }
if (y_dims.size() == 1) { if (y_dims.size() == 1) {
y_bd_dims[x_bd_dims.size() - 2] = y_dims[0]; (*y_bd_dims)[(*x_bd_dims).size() - 2] = y_dims[0];
} else if (y_dims.size() == 2) { } else if (y_dims.size() == 2) {
y_bd_dims[y_bd_dims.size() - 1] = y_dims[1]; (*y_bd_dims)[(*y_bd_dims).size() - 1] = y_dims[1];
y_bd_dims[y_bd_dims.size() - 2] = y_dims[0]; (*y_bd_dims)[(*y_bd_dims).size() - 2] = y_dims[0];
} else { } else {
for (size_t i = 0; i < y_dims.size(); ++i) { for (size_t i = 0; i < y_dims.size(); ++i) {
y_bd_dims[y_bd_dims.size() - y_dims.size() + i] = y_dims[i]; (*y_bd_dims)[(*y_bd_dims).size() - y_dims.size() + i] = y_dims[i];
} }
} }
if (!IsOutputFused(ctx) && x_dims.size() > 2 && y_dims.size() > 2) { if (!IsOutputFused(ctx) && x_dims.size() > 2 && y_dims.size() > 2) {
for (size_t i = 0; i < x_bd_dims.size() - 2; ++i) { for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] == 1 || (*x_bd_dims)[i] == (*y_bd_dims)[i] || (*x_bd_dims)[i] == 1 ||
y_bd_dims[i] == 1, (*y_bd_dims)[i] == 1,
true, true,
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"Tensor dimensions are incorrect for broadcasting." "Tensor dimensions are incorrect for broadcasting."
"Dimensions in X and Y must be same or equal to 1, but " "Dimensions in X and Y must be same or equal to 1, but "
"received x_dim[%d]=%d and y_dims[%d]= %d", "received x_dim[%d]=%d and y_dims[%d]= %d",
i, x_bd_dims[i], i, y_bd_dims[i])); i,
out_dims[i] = std::max(x_bd_dims[i], y_bd_dims[i]); (*x_bd_dims)[i],
i,
(*y_bd_dims)[i]));
(*out_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
} }
out->Resize(phi::make_ddim(out_dims)); out->Resize(phi::make_ddim((*out_dims)));
} }
} }
...@@ -238,11 +257,20 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -238,11 +257,20 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
std::vector<int64_t> x_bd_dims(ndims, 1); std::vector<int64_t> x_bd_dims(ndims, 1);
std::vector<int64_t> y_bd_dims(ndims, 1); std::vector<int64_t> y_bd_dims(ndims, 1);
CalculateMatrixDims(ctx, x_dims, y_dims, x_bd_dims, y_bd_dims, out_dims, CalculateMatrixDims(
out); ctx, x_dims, y_dims, &x_bd_dims, &y_bd_dims, &out_dims, out);
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, ExecuteMatMulV2<T>(ctx,
x_bd_dims, trans_x, y, y_bd_dims, trans_y, out, dev_ctx,
onednn_engine,
ctx.GetPlace(),
x,
x_bd_dims,
trans_x,
y,
y_bd_dims,
trans_y,
out,
out_dims); out_dims);
} }
}; };
...@@ -253,36 +281,46 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -253,36 +281,46 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); } void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
private: private:
void CalculateGradMatrixDims(const ExecutionContext& ctx, Tensor* dx_tmp, void CalculateGradMatrixDims(const ExecutionContext& ctx,
Tensor* dx_tmp,
Tensor* dy_tmp, Tensor* dy_tmp,
const std::vector<int64_t>& dx_dims, const std::vector<int64_t>& dx_dims,
const std::vector<int64_t>& dy_dims, const std::vector<int64_t>& dy_dims,
std::vector<int64_t>& dx_bd_dims, std::vector<int64_t>* dx_bd_dims,
std::vector<int64_t>& dy_bd_dims) const { std::vector<int64_t>* dy_bd_dims) const {
for (size_t i = 0; i < dx_dims.size() - 2; ++i) { for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
if (dx_dims[i] != dy_dims[i]) { if (dx_dims[i] != dy_dims[i]) {
if (dx_dims[i] == 1) { if (dx_dims[i] == 1) {
dx_bd_dims[i] = dy_dims[i]; (*dx_bd_dims)[i] = dy_dims[i];
} else { } else {
dy_bd_dims[i] = dx_dims[i]; (*dy_bd_dims)[i] = dx_dims[i];
} }
} }
} }
dx_tmp->Resize(phi::make_ddim(dx_bd_dims)); dx_tmp->Resize(phi::make_ddim((*dx_bd_dims)));
dx_tmp->mutable_data<T>(ctx.GetPlace()); dx_tmp->mutable_data<T>(ctx.GetPlace());
dy_tmp->Resize(phi::make_ddim(dy_bd_dims)); dy_tmp->Resize(phi::make_ddim((*dy_bd_dims)));
dy_tmp->mutable_data<T>(ctx.GetPlace()); dy_tmp->mutable_data<T>(ctx.GetPlace());
} }
void ReduceSumForMatmulGradOutput( void ReduceSumForMatmulGradOutput(
const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx, const ExecutionContext& ctx,
const dnnl::engine onednn_engine, const Tensor* dx_tmp, Tensor* dx, const MKLDNNDeviceContext& dev_ctx,
std::vector<int64_t>& dx_dims, const dnnl::engine onednn_engine,
const Tensor* dx_tmp,
Tensor* dx,
const std::vector<int64_t>& dx_dims,
const std::vector<int64_t>& squeezed_dims) const { const std::vector<int64_t>& squeezed_dims) const {
paddle::platform::ReductionMKLDNNHandler<T> handler( paddle::platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine, dnnl::algorithm::reduction_sum,
ctx.GetPlace(), dx_tmp, dx, dx_dims); 0.0f,
0.0f,
onednn_engine,
ctx.GetPlace(),
dx_tmp,
dx,
dx_dims);
auto src_memory_p = handler.AcquireSrcMemory(dx_tmp); auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
auto dst_memory_p = handler.AcquireDstMemory(dx); auto dst_memory_p = handler.AcquireDstMemory(dx);
...@@ -326,8 +364,8 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -326,8 +364,8 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
} else if (x_dims.size() != y_dims.size()) { } else if (x_dims.size() != y_dims.size()) {
is_broadcast = true; is_broadcast = true;
} else { } else {
is_broadcast = is_broadcast = !std::equal(x_dims.cbegin(),
!std::equal(x_dims.cbegin(), x_dims.cbegin() + x_dims.size() - 2, x_dims.cbegin() + x_dims.size() - 2,
y_dims.cbegin()); y_dims.cbegin());
} }
...@@ -362,44 +400,138 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -362,44 +400,138 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
std::vector<int64_t> dx_bd_dims(x_dims); std::vector<int64_t> dx_bd_dims(x_dims);
std::vector<int64_t> dy_bd_dims(y_dims); std::vector<int64_t> dy_bd_dims(y_dims);
CalculateGradMatrixDims(ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, dx_bd_dims, CalculateGradMatrixDims(
dy_bd_dims); ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims);
if (trans_x && trans_y) { if (trans_x && trans_y) {
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, y_dims, ExecuteMatMulV2<T>(ctx,
true, dout, dout_dims, true, &dx_tmp, dx_bd_dims, 1); dev_ctx,
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, onednn_engine,
dout_dims, true, x, x_dims, true, &dy_tmp, dy_bd_dims, ctx.GetPlace(),
y,
y_dims,
true,
dout,
dout_dims,
true,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
true,
x,
x_dims,
true,
&dy_tmp,
dy_bd_dims,
2); 2);
} else if (trans_x) { } else if (trans_x) {
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, y_dims, ExecuteMatMulV2<T>(ctx,
false, dout, dout_dims, true, &dx_tmp, dx_bd_dims, 1); dev_ctx,
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_dims, onednn_engine,
false, dout, dout_dims, false, &dy_tmp, dy_bd_dims, 2); ctx.GetPlace(),
y,
y_dims,
false,
dout,
dout_dims,
true,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
x,
x_dims,
false,
dout,
dout_dims,
false,
&dy_tmp,
dy_bd_dims,
2);
} else if (trans_y) { } else if (trans_y) {
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, ExecuteMatMulV2<T>(ctx,
dout_dims, false, y, y_dims, false, &dx_tmp, dev_ctx,
dx_bd_dims, 1); onednn_engine,
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, ctx.GetPlace(),
dout_dims, true, x, x_dims, false, &dy_tmp, dy_bd_dims, dout,
dout_dims,
false,
y,
y_dims,
false,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
true,
x,
x_dims,
false,
&dy_tmp,
dy_bd_dims,
2); 2);
} else { } else {
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, ExecuteMatMulV2<T>(ctx,
dout_dims, false, y, y_dims, true, &dx_tmp, dx_bd_dims, dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
false,
y,
y_dims,
true,
&dx_tmp,
dx_bd_dims,
1); 1);
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_dims, ExecuteMatMulV2<T>(ctx,
true, dout, dout_dims, false, &dy_tmp, dy_bd_dims, 2); dev_ctx,
onednn_engine,
ctx.GetPlace(),
x,
x_dims,
true,
dout,
dout_dims,
false,
&dy_tmp,
dy_bd_dims,
2);
} }
if (x_dims != dx_bd_dims) { if (x_dims != dx_bd_dims) {
ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dx_tmp, dx, ReduceSumForMatmulGradOutput(ctx,
x_dims, phi::vectorize(x->dims())); dev_ctx,
onednn_engine,
&dx_tmp,
dx,
x_dims,
phi::vectorize(x->dims()));
} else { } else {
*dx = std::move(dx_tmp); *dx = std::move(dx_tmp);
} }
if (y_dims != dy_bd_dims) { if (y_dims != dy_bd_dims) {
ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dy_tmp, dy, ReduceSumForMatmulGradOutput(ctx,
y_dims, phi::vectorize(y->dims())); dev_ctx,
onednn_engine,
&dy_tmp,
dy,
y_dims,
phi::vectorize(y->dims()));
} else { } else {
*dy = std::move(dy_tmp); *dy = std::move(dy_tmp);
} }
...@@ -413,10 +545,14 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -413,10 +545,14 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
}; };
} // anonymous namespace } // anonymous namespace
REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(matmul_v2,
MKLDNN,
::paddle::platform::CPUPlace,
MatMulV2MKLDNNKernel<float>, MatMulV2MKLDNNKernel<float>,
MatMulV2MKLDNNKernel<paddle::platform::bfloat16>); MatMulV2MKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(matmul_v2_grad, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(matmul_v2_grad,
MKLDNN,
::paddle::platform::CPUPlace,
MatMulV2GradMKLDNNKernel<float>, MatMulV2GradMKLDNNKernel<float>,
MatMulV2GradMKLDNNKernel<paddle::platform::bfloat16>); MatMulV2GradMKLDNNKernel<paddle::platform::bfloat16>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册