Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e7afa391
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e7afa391
编写于
3月 05, 2022
作者:
C
Chen Weihang
提交者:
GitHub
3月 05, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Phi] Remove eig op depend for svd_helper (#40174)
* remove eig dep for svd helper * fix win failed
上级
4be5448b
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
379 addition
and
44 deletion
+379
-44
paddle/fluid/operators/eig_op.h
paddle/fluid/operators/eig_op.h
+54
-38
paddle/phi/kernels/complex_kernel.h
paddle/phi/kernels/complex_kernel.h
+56
-4
paddle/phi/kernels/funcs/diag_functor.h
paddle/phi/kernels/funcs/diag_functor.h
+99
-0
paddle/phi/kernels/funcs/slice.h
paddle/phi/kernels/funcs/slice.h
+127
-0
paddle/phi/kernels/funcs/unsqueeze.h
paddle/phi/kernels/funcs/unsqueeze.h
+41
-0
paddle/phi/kernels/matmul_kernel.h
paddle/phi/kernels/matmul_kernel.h
+2
-2
未找到文件。
paddle/fluid/operators/eig_op.h
浏览文件 @
e7afa391
...
...
@@ -18,12 +18,19 @@
#include <algorithm>
#include <complex>
#include "paddle/fluid/operators/math/matrix_solve.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
#define EPSILON 1e-6
namespace
paddle
{
...
...
@@ -214,12 +221,17 @@ class EigKernel : public framework::OpKernel<T> {
ApplyEigKernel
<
DeviceContext
,
phi
::
dtype
::
Real
<
T
>>
(
*
x
,
&
real_values
,
&
real_vectors
,
context
);
auto
dito
=
math
::
DeviceIndependenceTensorOperations
<
DeviceContext
,
phi
::
dtype
::
Real
<
T
>
,
Tout
>
(
context
);
auto
&
orig_dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
static_cast
<
const
typename
framework
::
ConvertToPhiContext
<
DeviceContext
>::
TYPE
&>
(
orig_dev_ctx
);
// 1. extract real part & imag part from real_values
Tensor
real_part
=
dito
.
Slice
(
real_values
,
{
-
1
},
{
0
},
{
order
});
Tensor
imag_part
=
dito
.
Slice
(
real_values
,
{
-
1
},
{
order
},
{
order
*
2
});
Tensor
real_part
=
phi
::
funcs
::
Slice
<
T
>
(
dev_ctx
,
real_values
,
{
-
1
},
{
0
},
{
order
});
Tensor
imag_part
=
phi
::
funcs
::
Slice
<
T
>
(
dev_ctx
,
real_values
,
{
-
1
},
{
order
},
{
order
*
2
});
// 2. construct complex values
auto
*
real_part_data
=
real_part
.
data
<
phi
::
dtype
::
Real
<
T
>>
();
...
...
@@ -233,7 +245,8 @@ class EigKernel : public framework::OpKernel<T> {
for_range
(
functor
);
// 3. construct complex vectors
Tensor
real_vector_trans
=
dito
.
Transpose
(
real_vectors
);
Tensor
real_vector_trans
=
phi
::
TransposeLast2Dim
<
T
>
(
dev_ctx
,
real_vectors
);
Tensor
out_vectors_trans
;
out_vectors_trans
.
mutable_data
<
Tout
>
(
x
->
dims
(),
context
.
GetPlace
());
ConstructComplexVectors
<
phi
::
dtype
::
Real
<
T
>
,
Tout
>
(
...
...
@@ -251,45 +264,48 @@ class EigKernel : public framework::OpKernel<T> {
}
};
template
<
typename
DeviceContext
,
typename
T
out
>
template
<
typename
DeviceContext
,
typename
T
>
void
ComputeBackwardForComplexInput
(
const
Tensor
&
V
,
const
Tensor
&
L
,
const
Tensor
&
gL
,
const
Tensor
&
gV
,
T
out
*
x_grad_data
,
int
batch_count
,
int
order
,
T
*
x_grad_data
,
int
batch_count
,
int
order
,
const
framework
::
ExecutionContext
&
context
)
{
auto
dito
=
math
::
DeviceIndependenceTensorOperations
<
DeviceContext
,
Tout
,
Tout
>
(
context
);
Tensor
trans_v
=
dito
.
Transpose
(
V
);
Tensor
Vh
=
dito
.
Conj
(
trans_v
);
Tensor
Lconj
=
dito
.
Conj
(
L
);
Tensor
Econj
=
dito
.
Sub
(
dito
.
Unsqueeze
(
Lconj
,
-
2
),
dito
.
Unsqueeze
(
Lconj
,
-
1
));
Tensor
VhgV
=
dito
.
Matmul
(
Vh
,
gV
);
Tensor
diag_real
=
dito
.
Real
(
VhgV
);
Tensor
diag_res
=
dito
.
BatchDiag
(
diag_real
,
batch_count
);
Tensor
diag_unsqueezed
=
dito
.
Unsqueeze
(
diag_res
,
-
2
);
auto
&
orig_dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
static_cast
<
const
typename
framework
::
ConvertToPhiContext
<
DeviceContext
>::
TYPE
&>
(
orig_dev_ctx
);
Tensor
trans_v
=
phi
::
TransposeLast2Dim
<
T
>
(
dev_ctx
,
V
);
Tensor
Vh
=
phi
::
Conj
<
T
>
(
dev_ctx
,
trans_v
);
Tensor
Lconj
=
phi
::
Conj
<
T
>
(
dev_ctx
,
L
);
Tensor
Econj
=
phi
::
Subtract
<
T
>
(
dev_ctx
,
phi
::
funcs
::
Unsqueeze
(
Lconj
,
-
2
),
phi
::
funcs
::
Unsqueeze
(
Lconj
,
-
1
));
Tensor
VhgV
=
phi
::
Matmul
<
T
>
(
dev_ctx
,
Vh
,
gV
);
Tensor
diag_real
=
phi
::
Real
<
T
>
(
dev_ctx
,
VhgV
);
Tensor
diag_res
=
phi
::
funcs
::
BatchDiag
<
T
>
(
dev_ctx
,
diag_real
,
batch_count
);
Tensor
diag_unsqueezed
=
phi
::
funcs
::
Unsqueeze
(
diag_res
,
-
2
);
// turn diag_unsqueezed into complex
auto
numel
=
diag_unsqueezed
.
numel
();
Tensor
diag_unsqueezed_complex
;
auto
*
data_diag_un
=
diag_unsqueezed
.
data
<
phi
::
dtype
::
Real
<
T
out
>>
();
auto
*
data_diag_un_com
=
diag_unsqueezed_complex
.
mutable_data
<
T
out
>
(
auto
*
data_diag_un
=
diag_unsqueezed
.
data
<
phi
::
dtype
::
Real
<
T
>>
();
auto
*
data_diag_un_com
=
diag_unsqueezed_complex
.
mutable_data
<
T
>
(
diag_unsqueezed
.
dims
(),
context
.
GetPlace
(),
static_cast
<
size_t
>
(
numel
*
sizeof
(
T
out
)));
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
platform
::
ForRange
<
DeviceContext
>
for_range
(
dev_ctx
,
numel
);
phi
::
funcs
::
RealToComplexFunctor
<
T
out
>
functor
(
data_diag_un
,
data_diag_un_com
,
numel
);
static_cast
<
size_t
>
(
numel
*
sizeof
(
T
)));
platform
::
ForRange
<
DeviceContext
>
for_range
(
orig_
dev_ctx
,
numel
);
phi
::
funcs
::
RealToComplexFunctor
<
T
>
functor
(
data_diag_un
,
data_diag_un_com
,
numel
);
for_range
(
functor
);
// real tensor multiply complex tensor in broadcast manner
Tensor
res1
=
dito
.
RealMulComplex
(
V
,
diag_unsqueezed_complex
);
Tensor
res2
=
dito
.
Matmul
(
Vh
,
res1
);
Tensor
result
=
dito
.
Sub
(
VhgV
,
res2
);
Tensor
res1
=
phi
::
Multiply
<
T
>
(
dev_ctx
,
V
,
diag_unsqueezed_complex
);
Tensor
res2
=
phi
::
Matmul
<
T
>
(
dev_ctx
,
Vh
,
res1
);
Tensor
result
=
phi
::
Subtract
<
T
>
(
dev_ctx
,
VhgV
,
res2
);
result
.
mutable_data
<
Tout
>
(
V
.
dims
(),
context
.
GetPlace
());
result
=
dito
.
Div
(
result
,
Econj
);
result
=
dito
.
DiagFill
(
order
,
order
,
order
,
0
,
gL
,
result
);
Tensor
rhs
=
dito
.
Matmul
(
result
,
Vh
);
result
.
mutable_data
<
T
>
(
V
.
dims
(),
context
.
GetPlace
());
result
=
phi
::
Divide
<
T
>
(
dev_ctx
,
result
,
Econj
);
result
=
phi
::
funcs
::
DiagFill
<
T
,
T
>
(
dev_ctx
,
order
,
order
,
order
,
0
,
gL
,
result
);
Tensor
rhs
=
phi
::
Matmul
<
T
>
(
dev_ctx
,
result
,
Vh
);
// solve linear system
// solve(Vh, rhs, out, m, k)
...
...
@@ -298,10 +314,10 @@ void ComputeBackwardForComplexInput(
// x_grad: out
int
m
=
Vh
.
dims
()[
Vh
.
dims
().
size
()
-
1
];
int
k
=
rhs
.
dims
()[
rhs
.
dims
().
size
()
-
1
];
auto
*
matrix_data
=
Vh
.
data
<
T
out
>
();
auto
*
rhs_data
=
rhs
.
data
<
T
out
>
();
math
::
SolveLinearSystem
<
T
out
>
(
matrix_data
,
rhs_data
,
x_grad_data
,
m
,
k
,
batch_count
);
auto
*
matrix_data
=
Vh
.
data
<
T
>
();
auto
*
rhs_data
=
rhs
.
data
<
T
>
();
math
::
SolveLinearSystem
<
T
>
(
matrix_data
,
rhs_data
,
x_grad_data
,
m
,
k
,
batch_count
);
}
template
<
typename
DeviceContext
,
typename
T
,
typename
Tout
>
...
...
paddle/phi/kernels/complex_kernel.h
浏览文件 @
e7afa391
...
...
@@ -24,6 +24,12 @@ namespace phi {
template
<
typename
T
,
typename
Context
>
void
ConjKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
);
template
<
typename
T
,
typename
Context
>
void
RealKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
);
template
<
typename
T
,
typename
Context
>
void
ImagKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
);
// If T is complex
template
<
typename
T
,
...
...
@@ -50,10 +56,56 @@ DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
return
x
;
}
template
<
typename
T
,
typename
Context
>
void
RealKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
);
// If T is complex
template
<
typename
T
,
typename
Context
,
std
::
enable_if_t
<
std
::
is_same
<
T
,
phi
::
dtype
::
complex
<
float
>
>::
value
||
std
::
is_same
<
T
,
phi
::
dtype
::
complex
<
double
>>::
value
,
bool
>
=
true
>
DenseTensor
Real
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
)
{
auto
dense_out
=
phi
::
Empty
<
T
,
Context
>
(
dev_ctx
);
MetaTensor
meta_out
(
&
dense_out
);
RealAndImagInferMeta
(
x
,
&
meta_out
);
RealKernel
<
T
>
(
dev_ctx
,
x
,
&
dense_out
);
return
dense_out
;
}
template
<
typename
T
,
typename
Context
>
void
ImagKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
);
// If T is not complex
template
<
typename
T
,
typename
Context
,
std
::
enable_if_t
<!
std
::
is_same
<
T
,
phi
::
dtype
::
complex
<
float
>
>::
value
&&
!
std
::
is_same
<
T
,
phi
::
dtype
::
complex
<
double
>>::
value
,
bool
>
=
true
>
DenseTensor
Real
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
)
{
return
x
;
}
// If T is complex
template
<
typename
T
,
typename
Context
,
std
::
enable_if_t
<
std
::
is_same
<
T
,
phi
::
dtype
::
complex
<
float
>
>::
value
||
std
::
is_same
<
T
,
phi
::
dtype
::
complex
<
double
>>::
value
,
bool
>
=
true
>
DenseTensor
Imag
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
)
{
auto
dense_out
=
phi
::
Empty
<
T
,
Context
>
(
dev_ctx
);
MetaTensor
meta_out
(
&
dense_out
);
RealAndImagInferMeta
(
x
,
&
meta_out
);
ImagKernel
<
T
>
(
dev_ctx
,
x
,
&
dense_out
);
return
dense_out
;
}
// If T is not complex
template
<
typename
T
,
typename
Context
,
std
::
enable_if_t
<!
std
::
is_same
<
T
,
phi
::
dtype
::
complex
<
float
>
>::
value
&&
!
std
::
is_same
<
T
,
phi
::
dtype
::
complex
<
double
>>::
value
,
bool
>
=
true
>
DenseTensor
Imag
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
)
{
return
x
;
}
}
// namespace phi
paddle/phi/kernels/funcs/diag_functor.h
浏览文件 @
e7afa391
...
...
@@ -14,6 +14,14 @@
#pragma once
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/for_range.h"
// TODO(paddle-dev): Remove this file when we can call related Kernel directly
namespace
phi
{
namespace
funcs
{
...
...
@@ -25,5 +33,96 @@ inline int ComputeStride(int axis, phi::DDim dims) {
return
size
;
}
template
<
typename
T
,
typename
ValueType
>
struct
DiagAndFillFunctor
{
DiagAndFillFunctor
(
const
int
m
,
const
int
n
,
const
int
num_lower_diags
,
const
int
num_upper_diags
,
const
ValueType
*
scale
,
const
T
*
input
,
T
*
output
)
:
m_
(
m
),
n_
(
n
),
num_lower_diags_
(
num_lower_diags
),
num_upper_diags_
(
num_upper_diags
),
scale_
(
scale
),
input_
(
input
),
output_
(
output
)
{}
HOSTDEVICE
void
operator
()(
size_t
index
)
const
{
const
int
col
=
index
%
n_
;
const
int
row
=
(
index
/
n_
)
%
m_
;
const
int
band_start
=
(
num_lower_diags_
<
0
?
0
:
row
-
num_lower_diags_
);
const
int
band_end
=
(
num_upper_diags_
<
0
?
n_
:
row
+
num_upper_diags_
+
1
);
if
(
col
<
band_start
||
col
>=
band_end
)
{
output_
[
index
]
=
input_
[
index
];
}
else
if
(
col
==
band_end
-
1
)
{
output_
[
index
]
=
static_cast
<
T
>
(
scale_
[
index
%
m_
]);
}
else
{
output_
[
index
]
=
input_
[
index
];
}
}
private:
const
int
m_
,
n_
,
num_lower_diags_
,
num_upper_diags_
;
const
ValueType
*
scale_
;
const
T
*
input_
;
T
*
output_
;
};
template
<
typename
T
,
typename
ValueType
,
typename
Context
>
DenseTensor
DiagFill
(
const
Context
&
dev_ctx
,
const
int
m
,
const
int
n
,
const
int
num_lower_diags
,
const
int
num_upper_diags
,
const
DenseTensor
&
scale
,
const
DenseTensor
&
input
)
{
DenseTensor
out
;
out
.
Resize
(
input
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
&
out
);
funcs
::
ForRange
<
Context
>
for_range
(
dev_ctx
,
input
.
numel
());
DiagAndFillFunctor
<
T
,
ValueType
>
diag_and_copy_functor
(
m
,
n
,
num_lower_diags
,
num_upper_diags
,
scale
.
data
<
ValueType
>
(),
input
.
data
<
T
>
(),
out
.
data
<
T
>
());
for_range
(
diag_and_copy_functor
);
return
out
;
}
template
<
typename
T
,
typename
Context
>
DenseTensor
BatchDiag
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int
batch
)
{
DenseTensor
out
;
auto
*
x_data
=
x
.
data
<
phi
::
dtype
::
Real
<
T
>>
();
auto
numel
=
x
.
numel
();
out
.
Resize
(
x
.
dims
());
auto
*
out_data
=
dev_ctx
.
template
HostAlloc
<
phi
::
dtype
::
Real
<
T
>
>
(
&
out
,
static_cast
<
size_t
>
(
numel
*
sizeof
(
phi
::
dtype
::
Real
<
T
>
)));
auto
x_dims
=
x
.
dims
();
int
num_dims
=
x_dims
.
size
();
std
::
vector
<
int
>
out_shape
;
for
(
int
i
=
0
;
i
<
num_dims
-
1
;
++
i
)
{
out_shape
.
push_back
(
x
.
dims
()[
i
]);
}
out
.
Resize
(
phi
::
make_ddim
(
out_shape
));
int
order
=
x
.
dims
()[
num_dims
-
1
];
int
stride_out
=
order
*
order
;
int
stride_in
=
order
+
1
;
for
(
int
i
=
0
;
i
<
batch
;
++
i
)
{
for
(
int
j
=
0
;
j
<
order
;
++
j
)
{
out_data
[
i
*
order
+
j
]
=
x_data
[
stride_out
*
i
+
stride_in
*
j
];
}
}
return
out
;
}
}
// namespace funcs
}
// namespace phi
paddle/phi/kernels/funcs/slice.h
0 → 100644
浏览文件 @
e7afa391
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
// TODO(paddle-dev): Remove this file when we can call related Kernel directly
namespace
phi
{
namespace
funcs
{
template
<
typename
Context
,
typename
T
,
size_t
D
>
void
EigenSliceWrapper
(
const
Context
&
dev_ctx
,
const
DenseTensor
*
in
,
const
std
::
vector
<
int
>&
start
,
const
std
::
vector
<
int
>&
end
,
DenseTensor
*
out
)
{
// Slice by call Eigen Tensor Function `.slice()`
size_t
rank
=
in
->
dims
().
size
();
PADDLE_ENFORCE_EQ
(
start
.
size
(),
rank
,
errors
::
InvalidArgument
(
"EigenSliceWrapper function start "
"argument must have the same length as input rank."
));
PADDLE_ENFORCE_EQ
(
end
.
size
(),
rank
,
errors
::
InvalidArgument
(
"EigenSliceWrapper function end "
"argument must have the same length as input rank."
));
auto
eigen_place_ptr
=
dev_ctx
.
eigen_device
();
auto
eigen_place
=
*
eigen_place_ptr
;
auto
out_t
=
phi
::
EigenTensor
<
T
,
D
>::
From
(
*
out
,
out
->
dims
());
auto
in_t
=
phi
::
EigenTensor
<
T
,
D
>::
From
(
*
in
,
in
->
dims
());
Eigen
::
DSizes
<
int
,
D
>
offsets_32bit
,
extents_32bit
;
for
(
size_t
i
=
0
;
i
<
D
;
i
++
)
{
offsets_32bit
[
i
]
=
start
[
i
];
extents_32bit
[
i
]
=
end
[
i
];
}
EigenSlice
<
std
::
decay_t
<
decltype
(
eigen_place
)
>
,
T
,
D
>::
Eval
(
eigen_place
,
phi
::
To32BitIndex
(
out_t
),
phi
::
To32BitIndex
(
in_t
),
offsets_32bit
,
extents_32bit
);
}
#define SLICE_RANK_CASE(N) \
case N: { \
EigenSliceWrapper<Context, T, N>(dev_ctx, &x, offset, extends, &ret); \
break; \
}
template
<
typename
T
,
typename
Context
>
DenseTensor
Slice
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
std
::
vector
<
int
>
axes
,
std
::
vector
<
int
>
starts
,
std
::
vector
<
int
>
ends
)
{
DenseTensor
ret
;
std
::
vector
<
int
>
new_axes
=
axes
;
std
::
vector
<
int
>
out_shape
=
phi
::
vectorize
<
int
>
(
x
.
dims
());
size_t
rank
=
out_shape
.
size
();
PADDLE_ENFORCE_EQ
(
axes
.
size
(),
starts
.
size
(),
errors
::
InvalidArgument
(
"Slice Operator Argument Invalided"
));
PADDLE_ENFORCE_EQ
(
ends
.
size
(),
starts
.
size
(),
errors
::
InvalidArgument
(
"Slice Operator Argument Invalided"
));
for
(
unsigned
int
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
int
axis
=
axes
[
i
];
if
(
axis
<
0
)
axis
=
rank
+
axis
;
new_axes
[
i
]
=
axis
;
// change negative to positive
int
st
=
starts
[
i
];
int
ed
=
ends
[
i
];
PADDLE_ENFORCE_GT
(
ed
,
st
,
errors
::
InvalidArgument
(
"C++ Slice Operation Not Support End < Start"
));
out_shape
[
axis
]
=
ed
-
st
;
}
std
::
vector
<
int
>
offset
(
rank
),
extends
(
rank
);
for
(
size_t
i
=
0
;
i
<
rank
;
++
i
)
{
offset
[
i
]
=
0
;
extends
[
i
]
=
x
.
dims
()[
i
];
}
for
(
size_t
i
=
0
;
i
<
new_axes
.
size
();
++
i
)
{
offset
[
new_axes
[
i
]]
=
starts
[
i
];
extends
[
new_axes
[
i
]]
=
ends
[
i
]
-
starts
[
i
];
}
ret
.
Resize
(
phi
::
make_ddim
(
out_shape
));
dev_ctx
.
template
Alloc
<
T
>(
&
ret
);
switch
(
rank
)
{
SLICE_RANK_CASE
(
1
);
SLICE_RANK_CASE
(
2
);
SLICE_RANK_CASE
(
3
);
SLICE_RANK_CASE
(
4
);
SLICE_RANK_CASE
(
5
);
SLICE_RANK_CASE
(
6
);
default:
{
PADDLE_THROW
(
errors
::
InvalidArgument
(
"Invalid Rank number, "
"currently only support rank between 2~6"
));
}
}
return
ret
;
}
}
// namespace funcs
}
// namespace phi
paddle/phi/kernels/funcs/unsqueeze.h
0 → 100644
浏览文件 @
e7afa391
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
// TODO(paddle-dev): Remove this file when we can call related Kernel directly
namespace
phi
{
namespace
funcs
{
inline
const
DenseTensor
Unsqueeze
(
const
DenseTensor
&
x
,
int
axis
=
0
)
{
// don't copy data, only change the dims
DenseTensor
out
(
x
);
std
::
vector
<
int
>
out_shape
=
phi
::
vectorize
<
int
>
(
x
.
dims
());
if
(
axis
>=
0
)
{
auto
index
=
(
out_shape
.
begin
()
+
axis
);
out_shape
.
insert
(
index
,
1
);
}
else
if
(
axis
<
0
)
{
auto
index
=
(
out_shape
.
end
()
+
axis
+
1
);
out_shape
.
insert
(
index
,
1
);
}
out
.
Resize
(
phi
::
make_ddim
(
out_shape
));
return
out
;
}
}
// namespace funcs
}
// namespace phi
paddle/phi/kernels/matmul_kernel.h
浏览文件 @
e7afa391
...
...
@@ -33,8 +33,8 @@ template <typename T, typename Context>
DenseTensor
Matmul
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
bool
transpose_x
,
bool
transpose_y
)
{
bool
transpose_x
=
false
,
bool
transpose_y
=
false
)
{
auto
dense_out
=
Empty
<
T
,
Context
>
(
dev_ctx
);
MetaTensor
meta_out
(
&
dense_out
);
MatmulInferMeta
(
x
,
y
,
transpose_x
,
transpose_y
,
&
meta_out
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录