Unverified commit 71f74f5c, authored by Jacek Czaja, committed by GitHub

Fix to CI (#44744)

* - fix

* - another fix

* lint
Parent 2a8219c1
@@ -103,6 +103,34 @@ static paddle::framework::DDim ColumnMatrixDimsFromVector(
  return y_dim.size() > 1 ? y_dim : phi::make_ddim({y_dim[0], 1});
}
phi::DDim GetDimForInput(const ExecutionContext &ctx, std::string input_name) {
  auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
  auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
  auto input_dims = ctx.Input<Tensor>(input_name)->dims();
  if (!shape.empty() && !axis.empty()) {
    auto it_zero = std::find(shape.begin(), shape.end(), 0);
    if (it_zero != shape.end()) {
      for (uint64_t i = 0; i < shape.size(); i++) {
        if (shape[i] == 0) {
          PADDLE_ENFORCE_LT(i,
                            input_dims.size(),
                            paddle::platform::errors::InvalidArgument(
                                "The index of 0 in fused_reshape_%s ",
                                "should be less than output dim size, ",
                                "but the index is %d and output dim size is %d",
                                input_name,
                                i,
                                input_dims.size()));
          shape[i] = input_dims.at(i);
        }
      }
    }
    return input_dims.reshape(shape).transpose(axis);
  }
  return input_dims;
}
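For reference (this illustration is not part of the commit), a minimal standalone sketch of the zero-substitution rule that the GetDimForInput above applies before reshaping: a 0 at position i in the fused_reshape_<input_name> attribute is replaced by the i-th input dimension, and the PADDLE_ENFORCE_LT check guarantees that i is a valid input-dimension index. The helper name ResolveZeros and the sample shapes are hypothetical.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Standalone illustration of the substitution performed above: every 0 in the
// requested reshape target is replaced by the corresponding input dimension.
std::vector<int64_t> ResolveZeros(std::vector<int64_t> shape,
                                  const std::vector<int64_t> &input_dims) {
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 0) {
      // Mirrors the PADDLE_ENFORCE_LT check: the index of the 0 must be
      // smaller than the number of input dimensions.
      assert(i < input_dims.size());
      shape[i] = input_dims[i];
    }
  }
  return shape;
}

int main() {
  // Hypothetical input of shape [2, 3, 4] with fused_reshape = [0, 0, 2, 2]:
  // the two leading zeros are taken from the input, giving [2, 3, 2, 2].
  for (int64_t d : ResolveZeros({0, 0, 2, 2}, {2, 3, 4})) std::cout << d << ' ';
  std::cout << '\n';
}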
template <typename XT, typename YT, typename OT>
class MatMulMKLDNNHandler
: public paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul> {
@@ -235,36 +263,6 @@ class MatMulMKLDNNHandler
out_strides;
};
phi::DDim GetDimForInput(const ExecutionContext &ctx,
                         std::string input_name) {
  auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
  auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
  auto input_dims = ctx.Input<Tensor>(input_name)->dims();
  if (!shape.empty() && !axis.empty()) {
    auto it_zero = std::find(shape.begin(), shape.end(), 0);
    if (it_zero != shape.end()) {
      for (uint64_t i = 0; i < shape.size(); i++) {
        if (shape[i] == 0) {
          PADDLE_ENFORCE_LT(
              i,
              input_dims.size(),
              paddle::platform::errors::InvalidArgument(
                  "The index of 0 in fused_reshape_%s ",
                  "should be less than output dim size, ",
                  "but the index is %d and output dim size is %d",
                  input_name,
                  i,
                  input_dims.size()));
          shape[i] = input_dims.at(i);
        }
      }
    }
    return input_dims.reshape(shape).transpose(axis);
  }
  return input_dims;
}
std::pair<phi::funcs::MatDescriptor, memory::dims> GetInputDimsAndStrides(
const ExecutionContext &ctx, std::string input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
@@ -599,6 +597,23 @@ std::vector<int64_t> GetInputStrides(const ExecutionContext &ctx,
  auto input_dims = ctx.Input<Tensor>(input_name)->dims();
  auto new_dims = input_dims;
  if (!shape.empty() && !axis.empty()) {
    auto it_zero = std::find(shape.begin(), shape.end(), 0);
    if (it_zero != shape.end()) {
      for (uint64_t i = 0; i < shape.size(); i++) {
        if (shape[i] == 0) {
          PADDLE_ENFORCE_LT(i,
                            input_dims.size(),
                            paddle::platform::errors::InvalidArgument(
                                "The index of 0 in fused_reshape_%s ",
                                "should be less than output dim size, ",
                                "but the index is %d and output dim size is %d",
                                input_name,
                                i,
                                input_dims.size()));
          shape[i] = input_dims.at(i);
        }
      }
    }
    new_dims = input_dims.reshape(shape).transpose(axis);
  }
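As general background (not part of the commit, and not the body of GetInputStrides, which continues beyond this hunk), this is why the zeros in the fused_reshape attribute must be resolved to concrete sizes before any stride can be derived: a row-major stride is a product of the trailing dimension sizes, so every dimension has to be known. A small sketch with hypothetical dims:

#include <cstdint>
#include <iostream>
#include <vector>

// Row-major strides derived from a set of resolved dims: each stride is the
// product of all dimension sizes to its right.
std::vector<int64_t> RowMajorStrides(const std::vector<int64_t> &dims) {
  std::vector<int64_t> strides(dims.size(), 1);
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * dims[i + 1];
  return strides;
}

int main() {
  // Hypothetical resolved dims [2, 3, 20] -> strides [60, 20, 1].
  for (int64_t s : RowMajorStrides({2, 3, 20})) std::cout << s << ' ';
  std::cout << '\n';
}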
@@ -693,18 +708,6 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
out->set_format(format);
}
paddle::framework::DDim GetDimForInput(
    const paddle::framework::ExecutionContext &ctx,
    const std::string &input_name) {
  auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
  auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
  auto dim = ctx.Input<paddle::framework::Tensor>(input_name)->dims();
  if (!shape.empty() && !axis.empty()) {
    dim = dim.reshape(shape).transpose(axis);
  }
  return dim;
}
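For contrast (again illustrative only, not part of the commit), a small worked example of the reshape-then-transpose dim computation that both GetDimForInput variants perform, written with plain vectors; the variant shown here forwards the fused_reshape attribute to reshape unchanged, while the version added earlier in this diff resolves zeros first. The helper FusedDims and the sample shapes are hypothetical.

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative only: compute the logical output dims of
// input_dims.reshape(shape).transpose(axis) on plain vectors,
// resolving zeros in `shape` the way the newly added GetDimForInput does.
std::vector<int64_t> FusedDims(const std::vector<int64_t> &input_dims,
                               std::vector<int64_t> shape,
                               const std::vector<int> &axis) {
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 0) shape[i] = input_dims[i];  // zero substitution
  }
  std::vector<int64_t> out(shape.size());
  for (size_t i = 0; i < axis.size(); ++i) out[i] = shape[axis[i]];  // permute
  return out;
}

int main() {
  // Hypothetical: input [2, 3, 4, 5], fused_reshape [0, 0, 20],
  // fused_transpose [0, 2, 1] -> reshape to [2, 3, 20], then [2, 20, 3].
  for (int64_t d : FusedDims({2, 3, 4, 5}, {0, 0, 20}, {0, 2, 1}))
    std::cout << d << ' ';
  std::cout << '\n';
}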
template <typename T>
class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
public: