未验证 提交 5a55f13b 编写于 作者: W wangzhen38 提交者: GitHub

【code format】Fix cpplint style 4 (#43695)

* cpplint fix 2

* cpplint fix 2

* fix cpplint style 4

* fix cpplint style 4

* fix cpplint style 4

* fix cpplint style 4
上级 75080988
......@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef POLY_UTIL_CC_
#define POLY_UTIL_CC_
......@@ -26,61 +25,62 @@ using gpc::gpc_free_polygon;
using gpc::gpc_polygon_clip;
template <class T>
void Array2PointVec(const T*& box, const size_t box_size,
std::vector<Point_<T>>& vec) {
void Array2PointVec(const T* box,
const size_t box_size,
std::vector<Point_<T>>* vec) {
size_t pts_num = box_size / 2;
vec.resize(pts_num);
(*vec).resize(pts_num);
for (size_t i = 0; i < pts_num; i++) {
vec.at(i).x = box[2 * i];
vec.at(i).y = box[2 * i + 1];
(*vec).at(i).x = box[2 * i];
(*vec).at(i).y = box[2 * i + 1];
}
}
template <class T>
void Array2Poly(const T*& box, const size_t box_size, gpc::gpc_polygon& poly) {
void Array2Poly(const T* box, const size_t box_size, gpc::gpc_polygon* poly) {
size_t pts_num = box_size / 2;
poly.num_contours = 1;
poly.hole = (int*)malloc(sizeof(int));
poly.hole[0] = 0;
poly.contour = (gpc::gpc_vertex_list*)malloc(sizeof(gpc::gpc_vertex_list));
poly.contour->num_vertices = pts_num;
poly.contour->vertex =
(*poly).num_contours = 1;
(*poly).hole = reinterpret_cast<int*>(malloc(sizeof(int)));
(*poly).hole[0] = 0;
(*poly).contour = (gpc::gpc_vertex_list*)malloc(sizeof(gpc::gpc_vertex_list));
(*poly).contour->num_vertices = pts_num;
(*poly).contour->vertex =
(gpc::gpc_vertex*)malloc(sizeof(gpc::gpc_vertex) * pts_num);
for (size_t i = 0; i < pts_num; ++i) {
poly.contour->vertex[i].x = box[2 * i];
poly.contour->vertex[i].y = box[2 * i + 1];
(*poly).contour->vertex[i].x = box[2 * i];
(*poly).contour->vertex[i].y = box[2 * i + 1];
}
}
template <class T>
void PointVec2Poly(const std::vector<Point_<T>>& vec, gpc::gpc_polygon& poly) {
void PointVec2Poly(const std::vector<Point_<T>>& vec, gpc::gpc_polygon* poly) {
int pts_num = vec.size();
poly.num_contours = 1;
poly.hole = (int*)malloc(sizeof(int));
poly.hole[0] = 0;
poly.contour = (gpc::gpc_vertex_list*)malloc(sizeof(gpc::gpc_vertex_list));
poly.contour->num_vertices = pts_num;
poly.contour->vertex =
(*poly).num_contours = 1;
(*poly).hole = reinterpret_cast<int*>(malloc(sizeof(int)));
(*poly).hole[0] = 0;
(*poly).contour = (gpc::gpc_vertex_list*)malloc(sizeof(gpc::gpc_vertex_list));
(*poly).contour->num_vertices = pts_num;
(*poly).contour->vertex =
(gpc::gpc_vertex*)malloc(sizeof(gpc::gpc_vertex) * pts_num);
for (size_t i = 0; i < pts_num; ++i) {
poly.contour->vertex[i].x = vec[i].x;
poly.contour->vertex[i].y = vec[i].y;
(*poly).contour->vertex[i].x = vec[i].x;
(*poly).contour->vertex[i].y = vec[i].y;
}
}
template <class T>
void Poly2PointVec(const gpc::gpc_vertex_list& contour,
std::vector<Point_<T>>& vec) {
std::vector<Point_<T>>* vec) {
int pts_num = contour.num_vertices;
vec.resize(pts_num);
(*vec).resize(pts_num);
for (int i = 0; i < pts_num; i++) {
vec.at(i).x = contour.vertex[i].x;
vec.at(i).y = contour.vertex[i].y;
(*vec).at(i).x = contour.vertex[i].x;
(*vec).at(i).y = contour.vertex[i].y;
}
}
template <class T>
T GetContourArea(std::vector<Point_<T>>& vec) {
T GetContourArea(const std::vector<Point_<T>>& vec) {
size_t pts_num = vec.size();
if (pts_num < 3) return T(0.);
T area = T(0.);
......@@ -96,17 +96,19 @@ T PolyArea(const T* box, const size_t box_size, const bool normalized) {
// If coordinate values are is invalid
// if area size <= 0, return 0.
std::vector<Point_<T>> vec;
Array2PointVec<T>(box, box_size, vec);
Array2PointVec<T>(box, box_size, &vec);
return GetContourArea<T>(vec);
}
template <class T>
T PolyOverlapArea(const T* box1, const T* box2, const size_t box_size,
T PolyOverlapArea(const T* box1,
const T* box2,
const size_t box_size,
const bool normalized) {
gpc::gpc_polygon poly1;
gpc::gpc_polygon poly2;
Array2Poly<T>(box1, box_size, poly1);
Array2Poly<T>(box2, box_size, poly2);
Array2Poly<T>(box1, box_size, &poly1);
Array2Poly<T>(box2, box_size, &poly2);
gpc::gpc_polygon respoly;
gpc::gpc_op op = gpc::GPC_INT;
gpc::gpc_polygon_clip(op, &poly2, &poly1, &respoly);
......@@ -115,7 +117,7 @@ T PolyOverlapArea(const T* box1, const T* box2, const size_t box_size,
int contour_num = respoly.num_contours;
for (int i = 0; i < contour_num; ++i) {
std::vector<Point_<T>> resvec;
Poly2PointVec<T>(respoly.contour[i], resvec);
Poly2PointVec<T>(respoly.contour[i], &resvec);
// inter_area += std::fabs(cv::contourArea(resvec)) + 0.5f *
// (cv::arcLength(resvec, true));
inter_area += GetContourArea<T>(resvec);
......
......@@ -11,9 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef POLY_UTIL_H_
#define POLY_UTIL_H_
#pragma once
#include <vector>
......@@ -44,31 +42,32 @@ class Point_ {
};
template <class T>
void Array2PointVec(const T*& box, const size_t box_size,
std::vector<Point_<T>>& vec);
void Array2PointVec(const T* box,
const size_t box_size,
std::vector<Point_<T>>* vec);
template <class T>
void Array2Poly(const T*& box, const size_t box_size, gpc::gpc_polygon& poly);
void Array2Poly(const T* box, const size_t box_size, gpc::gpc_polygon* poly);
template <class T>
void PointVec2Poly(const std::vector<Point_<T>>& vec, gpc::gpc_polygon& poly);
void PointVec2Poly(const std::vector<Point_<T>>& vec, gpc::gpc_polygon* poly);
template <class T>
void Poly2PointVec(const gpc::gpc_vertex_list& contour,
std::vector<Point_<T>>& vec);
std::vector<Point_<T>>* vec);
template <class T>
T GetContourArea(std::vector<Point_<T>>& vec);
T GetContourArea(const std::vector<Point_<T>>& vec);
template <class T>
T PolyArea(const T* box, const size_t box_size, const bool normalized);
template <class T>
T PolyOverlapArea(const T* box1, const T* box2, const size_t box_size,
T PolyOverlapArea(const T* box1,
const T* box2,
const size_t box_size,
const bool normalized);
} // namespace operators
} // namespace paddle
#include "paddle/fluid/operators/detection/poly_util.cc"
#endif // POLY_UTIL_H_
......@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h"
namespace {
......@@ -49,19 +48,23 @@ static std::vector<int64_t> Transpose(const std::vector<int64_t>& x,
size_t axis_size = axis.size();
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(), axis_size,
PADDLE_ENFORCE_EQ(axis_set.size(),
axis_size,
paddle::platform::errors::InvalidArgument(
"In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(in_rank, axis_size,
PADDLE_ENFORCE_EQ(in_rank,
axis_size,
paddle::platform::errors::InvalidArgument(
"The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank, axis_size));
in_rank,
axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size,
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
axis_size,
paddle::platform::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
......@@ -85,7 +88,8 @@ std::vector<int64_t> GetInputStrides(const ExecutionContext& ctx,
auto& MatrixDimsFromVector =
input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector;
phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor(
MatrixDimsFromVector(new_dims), 0,
MatrixDimsFromVector(new_dims),
0,
ctx.Attr<bool>(std::string("trans_") +
static_cast<char>(std::tolower(input_name[0]))));
......@@ -125,16 +129,27 @@ template <typename T>
void ExecuteMatMulV2(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const dnnl::engine onednn_engine,
paddle::platform::Place cpu_place, const Tensor* x,
std::vector<int64_t>& x_dims, bool trans_x,
const Tensor* y, std::vector<int64_t>& y_dims,
bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
paddle::platform::Place cpu_place,
const Tensor* x,
const std::vector<int64_t>& x_dims,
bool trans_x,
const Tensor* y,
const std::vector<int64_t>& y_dims,
bool trans_y,
Tensor* out,
const std::vector<int64_t>& out_dims,
int execution_number = 0) {
std::vector<int64_t> x_strides_override = GetInputStrides(ctx, "X");
std::vector<int64_t> y_strides_override = GetInputStrides(ctx, "Y");
MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims,
trans_x, y_dims, trans_y, IsOutputFused(ctx),
x_strides_override, y_strides_override);
MatMulV2MKLDNNHandler<T> handler(onednn_engine,
ctx.GetPlace(),
x_dims,
trans_x,
y_dims,
trans_y,
IsOutputFused(ctx),
x_strides_override,
y_strides_override);
const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y);
......@@ -177,44 +192,48 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
void CalculateMatrixDims(const ExecutionContext& ctx,
const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& y_dims,
std::vector<int64_t>& x_bd_dims,
std::vector<int64_t>& y_bd_dims,
std::vector<int64_t>& out_dims, Tensor* out) const {
std::vector<int64_t>* x_bd_dims,
std::vector<int64_t>* y_bd_dims,
std::vector<int64_t>* out_dims,
Tensor* out) const {
if (x_dims.size() == 1) {
x_bd_dims[x_bd_dims.size() - 1] = x_dims[0];
(*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0];
} else if (x_dims.size() == 2) {
x_bd_dims[x_bd_dims.size() - 1] = x_dims[1];
x_bd_dims[x_bd_dims.size() - 2] = x_dims[0];
(*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[1];
(*x_bd_dims)[(*x_bd_dims).size() - 2] = x_dims[0];
} else {
for (size_t i = 0; i < x_dims.size(); ++i) {
x_bd_dims[x_bd_dims.size() - x_dims.size() + i] = x_dims[i];
(*x_bd_dims)[(*x_bd_dims).size() - x_dims.size() + i] = x_dims[i];
}
}
if (y_dims.size() == 1) {
y_bd_dims[x_bd_dims.size() - 2] = y_dims[0];
(*y_bd_dims)[(*x_bd_dims).size() - 2] = y_dims[0];
} else if (y_dims.size() == 2) {
y_bd_dims[y_bd_dims.size() - 1] = y_dims[1];
y_bd_dims[y_bd_dims.size() - 2] = y_dims[0];
(*y_bd_dims)[(*y_bd_dims).size() - 1] = y_dims[1];
(*y_bd_dims)[(*y_bd_dims).size() - 2] = y_dims[0];
} else {
for (size_t i = 0; i < y_dims.size(); ++i) {
y_bd_dims[y_bd_dims.size() - y_dims.size() + i] = y_dims[i];
(*y_bd_dims)[(*y_bd_dims).size() - y_dims.size() + i] = y_dims[i];
}
}
if (!IsOutputFused(ctx) && x_dims.size() > 2 && y_dims.size() > 2) {
for (size_t i = 0; i < x_bd_dims.size() - 2; ++i) {
for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) {
PADDLE_ENFORCE_EQ(
x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] == 1 ||
y_bd_dims[i] == 1,
(*x_bd_dims)[i] == (*y_bd_dims)[i] || (*x_bd_dims)[i] == 1 ||
(*y_bd_dims)[i] == 1,
true,
paddle::platform::errors::InvalidArgument(
"Tensor dimensions are incorrect for broadcasting."
"Dimensions in X and Y must be same or equal to 1, but "
"received x_dim[%d]=%d and y_dims[%d]= %d",
i, x_bd_dims[i], i, y_bd_dims[i]));
out_dims[i] = std::max(x_bd_dims[i], y_bd_dims[i]);
i,
(*x_bd_dims)[i],
i,
(*y_bd_dims)[i]));
(*out_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
}
out->Resize(phi::make_ddim(out_dims));
out->Resize(phi::make_ddim((*out_dims)));
}
}
......@@ -238,11 +257,20 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
std::vector<int64_t> x_bd_dims(ndims, 1);
std::vector<int64_t> y_bd_dims(ndims, 1);
CalculateMatrixDims(ctx, x_dims, y_dims, x_bd_dims, y_bd_dims, out_dims,
out);
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
x_bd_dims, trans_x, y, y_bd_dims, trans_y, out,
CalculateMatrixDims(
ctx, x_dims, y_dims, &x_bd_dims, &y_bd_dims, &out_dims, out);
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
x,
x_bd_dims,
trans_x,
y,
y_bd_dims,
trans_y,
out,
out_dims);
}
};
......@@ -253,36 +281,46 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
private:
void CalculateGradMatrixDims(const ExecutionContext& ctx, Tensor* dx_tmp,
void CalculateGradMatrixDims(const ExecutionContext& ctx,
Tensor* dx_tmp,
Tensor* dy_tmp,
const std::vector<int64_t>& dx_dims,
const std::vector<int64_t>& dy_dims,
std::vector<int64_t>& dx_bd_dims,
std::vector<int64_t>& dy_bd_dims) const {
std::vector<int64_t>* dx_bd_dims,
std::vector<int64_t>* dy_bd_dims) const {
for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
if (dx_dims[i] != dy_dims[i]) {
if (dx_dims[i] == 1) {
dx_bd_dims[i] = dy_dims[i];
(*dx_bd_dims)[i] = dy_dims[i];
} else {
dy_bd_dims[i] = dx_dims[i];
(*dy_bd_dims)[i] = dx_dims[i];
}
}
}
dx_tmp->Resize(phi::make_ddim(dx_bd_dims));
dx_tmp->Resize(phi::make_ddim((*dx_bd_dims)));
dx_tmp->mutable_data<T>(ctx.GetPlace());
dy_tmp->Resize(phi::make_ddim(dy_bd_dims));
dy_tmp->Resize(phi::make_ddim((*dy_bd_dims)));
dy_tmp->mutable_data<T>(ctx.GetPlace());
}
void ReduceSumForMatmulGradOutput(
const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx,
const dnnl::engine onednn_engine, const Tensor* dx_tmp, Tensor* dx,
std::vector<int64_t>& dx_dims,
const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const dnnl::engine onednn_engine,
const Tensor* dx_tmp,
Tensor* dx,
const std::vector<int64_t>& dx_dims,
const std::vector<int64_t>& squeezed_dims) const {
paddle::platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dx_tmp, dx, dx_dims);
dnnl::algorithm::reduction_sum,
0.0f,
0.0f,
onednn_engine,
ctx.GetPlace(),
dx_tmp,
dx,
dx_dims);
auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
auto dst_memory_p = handler.AcquireDstMemory(dx);
......@@ -326,9 +364,9 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
} else if (x_dims.size() != y_dims.size()) {
is_broadcast = true;
} else {
is_broadcast =
!std::equal(x_dims.cbegin(), x_dims.cbegin() + x_dims.size() - 2,
y_dims.cbegin());
is_broadcast = !std::equal(x_dims.cbegin(),
x_dims.cbegin() + x_dims.size() - 2,
y_dims.cbegin());
}
// if no broadcasting is needed, we can simply use matmul's grad and avoid
......@@ -362,44 +400,138 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
std::vector<int64_t> dx_bd_dims(x_dims);
std::vector<int64_t> dy_bd_dims(y_dims);
CalculateGradMatrixDims(ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, dx_bd_dims,
dy_bd_dims);
CalculateGradMatrixDims(
ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims);
if (trans_x && trans_y) {
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, y_dims,
true, dout, dout_dims, true, &dx_tmp, dx_bd_dims, 1);
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, true, x, x_dims, true, &dy_tmp, dy_bd_dims,
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
y,
y_dims,
true,
dout,
dout_dims,
true,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
true,
x,
x_dims,
true,
&dy_tmp,
dy_bd_dims,
2);
} else if (trans_x) {
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, y_dims,
false, dout, dout_dims, true, &dx_tmp, dx_bd_dims, 1);
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_dims,
false, dout, dout_dims, false, &dy_tmp, dy_bd_dims, 2);
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
y,
y_dims,
false,
dout,
dout_dims,
true,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
x,
x_dims,
false,
dout,
dout_dims,
false,
&dy_tmp,
dy_bd_dims,
2);
} else if (trans_y) {
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, false, y, y_dims, false, &dx_tmp,
dx_bd_dims, 1);
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, true, x, x_dims, false, &dy_tmp, dy_bd_dims,
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
false,
y,
y_dims,
false,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
true,
x,
x_dims,
false,
&dy_tmp,
dy_bd_dims,
2);
} else {
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, false, y, y_dims, true, &dx_tmp, dx_bd_dims,
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
false,
y,
y_dims,
true,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_dims,
true, dout, dout_dims, false, &dy_tmp, dy_bd_dims, 2);
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
x,
x_dims,
true,
dout,
dout_dims,
false,
&dy_tmp,
dy_bd_dims,
2);
}
if (x_dims != dx_bd_dims) {
ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dx_tmp, dx,
x_dims, phi::vectorize(x->dims()));
ReduceSumForMatmulGradOutput(ctx,
dev_ctx,
onednn_engine,
&dx_tmp,
dx,
x_dims,
phi::vectorize(x->dims()));
} else {
*dx = std::move(dx_tmp);
}
if (y_dims != dy_bd_dims) {
ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dy_tmp, dy,
y_dims, phi::vectorize(y->dims()));
ReduceSumForMatmulGradOutput(ctx,
dev_ctx,
onednn_engine,
&dy_tmp,
dy,
y_dims,
phi::vectorize(y->dims()));
} else {
*dy = std::move(dy_tmp);
}
......@@ -413,10 +545,14 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
};
} // anonymous namespace
REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace,
REGISTER_OP_KERNEL(matmul_v2,
MKLDNN,
::paddle::platform::CPUPlace,
MatMulV2MKLDNNKernel<float>,
MatMulV2MKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(matmul_v2_grad, MKLDNN, ::paddle::platform::CPUPlace,
REGISTER_OP_KERNEL(matmul_v2_grad,
MKLDNN,
::paddle::platform::CPUPlace,
MatMulV2GradMKLDNNKernel<float>,
MatMulV2GradMKLDNNKernel<paddle::platform::bfloat16>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册