Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e7afa391
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e7afa391
编写于
3月 05, 2022
作者:
C
Chen Weihang
提交者:
GitHub
3月 05, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Phi] Remove eig op depend for svd_helper (#40174)
* remove eig dep for svd helper * fix win failed
上级
4be5448b
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
379 addition
and
44 deletion
+379
-44
paddle/fluid/operators/eig_op.h
paddle/fluid/operators/eig_op.h
+54
-38
paddle/phi/kernels/complex_kernel.h
paddle/phi/kernels/complex_kernel.h
+56
-4
paddle/phi/kernels/funcs/diag_functor.h
paddle/phi/kernels/funcs/diag_functor.h
+99
-0
paddle/phi/kernels/funcs/slice.h
paddle/phi/kernels/funcs/slice.h
+127
-0
paddle/phi/kernels/funcs/unsqueeze.h
paddle/phi/kernels/funcs/unsqueeze.h
+41
-0
paddle/phi/kernels/matmul_kernel.h
paddle/phi/kernels/matmul_kernel.h
+2
-2
未找到文件。
paddle/fluid/operators/eig_op.h
浏览文件 @
e7afa391
...
...
@@ -18,12 +18,19 @@
#include <algorithm>
#include <complex>
#include "paddle/fluid/operators/math/matrix_solve.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
#define EPSILON 1e-6
namespace
paddle
{
...
...
@@ -214,12 +221,17 @@ class EigKernel : public framework::OpKernel<T> {
ApplyEigKernel
<
DeviceContext
,
phi
::
dtype
::
Real
<
T
>>
(
*
x
,
&
real_values
,
&
real_vectors
,
context
);
auto
dito
=
math
::
DeviceIndependenceTensorOperations
<
DeviceContext
,
phi
::
dtype
::
Real
<
T
>
,
Tout
>
(
context
);
auto
&
orig_dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
static_cast
<
const
typename
framework
::
ConvertToPhiContext
<
DeviceContext
>::
TYPE
&>
(
orig_dev_ctx
);
// 1. extract real part & imag part from real_values
Tensor
real_part
=
dito
.
Slice
(
real_values
,
{
-
1
},
{
0
},
{
order
});
Tensor
imag_part
=
dito
.
Slice
(
real_values
,
{
-
1
},
{
order
},
{
order
*
2
});
Tensor
real_part
=
phi
::
funcs
::
Slice
<
T
>
(
dev_ctx
,
real_values
,
{
-
1
},
{
0
},
{
order
});
Tensor
imag_part
=
phi
::
funcs
::
Slice
<
T
>
(
dev_ctx
,
real_values
,
{
-
1
},
{
order
},
{
order
*
2
});
// 2. construct complex values
auto
*
real_part_data
=
real_part
.
data
<
phi
::
dtype
::
Real
<
T
>>
();
...
...
@@ -233,7 +245,8 @@ class EigKernel : public framework::OpKernel<T> {
for_range
(
functor
);
// 3. construct complex vectors
Tensor
real_vector_trans
=
dito
.
Transpose
(
real_vectors
);
Tensor
real_vector_trans
=
phi
::
TransposeLast2Dim
<
T
>
(
dev_ctx
,
real_vectors
);
Tensor
out_vectors_trans
;
out_vectors_trans
.
mutable_data
<
Tout
>
(
x
->
dims
(),
context
.
GetPlace
());
ConstructComplexVectors
<
phi
::
dtype
::
Real
<
T
>
,
Tout
>
(
...
...
@@ -251,45 +264,48 @@ class EigKernel : public framework::OpKernel<T> {
}
};
template
<
typename
DeviceContext
,
typename
T
out
>
template
<
typename
DeviceContext
,
typename
T
>
void
ComputeBackwardForComplexInput
(
const
Tensor
&
V
,
const
Tensor
&
L
,
const
Tensor
&
gL
,
const
Tensor
&
gV
,
T
out
*
x_grad_data
,
int
batch_count
,
int
order
,
T
*
x_grad_data
,
int
batch_count
,
int
order
,
const
framework
::
ExecutionContext
&
context
)
{
auto
dito
=
math
::
DeviceIndependenceTensorOperations
<
DeviceContext
,
Tout
,
Tout
>
(
context
);
Tensor
trans_v
=
dito
.
Transpose
(
V
);
Tensor
Vh
=
dito
.
Conj
(
trans_v
);
Tensor
Lconj
=
dito
.
Conj
(
L
);
Tensor
Econj
=
dito
.
Sub
(
dito
.
Unsqueeze
(
Lconj
,
-
2
),
dito
.
Unsqueeze
(
Lconj
,
-
1
));
Tensor
VhgV
=
dito
.
Matmul
(
Vh
,
gV
);
Tensor
diag_real
=
dito
.
Real
(
VhgV
);
Tensor
diag_res
=
dito
.
BatchDiag
(
diag_real
,
batch_count
);
Tensor
diag_unsqueezed
=
dito
.
Unsqueeze
(
diag_res
,
-
2
);
auto
&
orig_dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
static_cast
<
const
typename
framework
::
ConvertToPhiContext
<
DeviceContext
>::
TYPE
&>
(
orig_dev_ctx
);
Tensor
trans_v
=
phi
::
TransposeLast2Dim
<
T
>
(
dev_ctx
,
V
);
Tensor
Vh
=
phi
::
Conj
<
T
>
(
dev_ctx
,
trans_v
);
Tensor
Lconj
=
phi
::
Conj
<
T
>
(
dev_ctx
,
L
);
Tensor
Econj
=
phi
::
Subtract
<
T
>
(
dev_ctx
,
phi
::
funcs
::
Unsqueeze
(
Lconj
,
-
2
),
phi
::
funcs
::
Unsqueeze
(
Lconj
,
-
1
));
Tensor
VhgV
=
phi
::
Matmul
<
T
>
(
dev_ctx
,
Vh
,
gV
);
Tensor
diag_real
=
phi
::
Real
<
T
>
(
dev_ctx
,
VhgV
);
Tensor
diag_res
=
phi
::
funcs
::
BatchDiag
<
T
>
(
dev_ctx
,
diag_real
,
batch_count
);
Tensor
diag_unsqueezed
=
phi
::
funcs
::
Unsqueeze
(
diag_res
,
-
2
);
// turn diag_unsqueezed into complex
auto
numel
=
diag_unsqueezed
.
numel
();
Tensor
diag_unsqueezed_complex
;
auto
*
data_diag_un
=
diag_unsqueezed
.
data
<
phi
::
dtype
::
Real
<
T
out
>>
();
auto
*
data_diag_un_com
=
diag_unsqueezed_complex
.
mutable_data
<
T
out
>
(
auto
*
data_diag_un
=
diag_unsqueezed
.
data
<
phi
::
dtype
::
Real
<
T
>>
();
auto
*
data_diag_un_com
=
diag_unsqueezed_complex
.
mutable_data
<
T
>
(
diag_unsqueezed
.
dims
(),
context
.
GetPlace
(),
static_cast
<
size_t
>
(
numel
*
sizeof
(
T
out
)));
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
platform
::
ForRange
<
DeviceContext
>
for_range
(
dev_ctx
,
numel
);
phi
::
funcs
::
RealToComplexFunctor
<
T
out
>
functor
(
data_diag_un
,
data_diag_un_com
,
numel
);
static_cast
<
size_t
>
(
numel
*
sizeof
(
T
)));
platform
::
ForRange
<
DeviceContext
>
for_range
(
orig_
dev_ctx
,
numel
);
phi
::
funcs
::
RealToComplexFunctor
<
T
>
functor
(
data_diag_un
,
data_diag_un_com
,
numel
);
for_range
(
functor
);
// real tensor multiply complex tensor in broadcast manner
Tensor
res1
=
dito
.
RealMulComplex
(
V
,
diag_unsqueezed_complex
);
Tensor
res2
=
dito
.
Matmul
(
Vh
,
res1
);
Tensor
result
=
dito
.
Sub
(
VhgV
,
res2
);
Tensor
res1
=
phi
::
Multiply
<
T
>
(
dev_ctx
,
V
,
diag_unsqueezed_complex
);
Tensor
res2
=
phi
::
Matmul
<
T
>
(
dev_ctx
,
Vh
,
res1
);
Tensor
result
=
phi
::
Subtract
<
T
>
(
dev_ctx
,
VhgV
,
res2
);
result
.
mutable_data
<
Tout
>
(
V
.
dims
(),
context
.
GetPlace
());
result
=
dito
.
Div
(
result
,
Econj
);
result
=
dito
.
DiagFill
(
order
,
order
,
order
,
0
,
gL
,
result
);
Tensor
rhs
=
dito
.
Matmul
(
result
,
Vh
);
result
.
mutable_data
<
T
>
(
V
.
dims
(),
context
.
GetPlace
());
result
=
phi
::
Divide
<
T
>
(
dev_ctx
,
result
,
Econj
);
result
=
phi
::
funcs
::
DiagFill
<
T
,
T
>
(
dev_ctx
,
order
,
order
,
order
,
0
,
gL
,
result
);
Tensor
rhs
=
phi
::
Matmul
<
T
>
(
dev_ctx
,
result
,
Vh
);
// solve linear system
// solve(Vh, rhs, out, m, k)
...
...
@@ -298,10 +314,10 @@ void ComputeBackwardForComplexInput(
// x_grad: out
int
m
=
Vh
.
dims
()[
Vh
.
dims
().
size
()
-
1
];
int
k
=
rhs
.
dims
()[
rhs
.
dims
().
size
()
-
1
];
auto
*
matrix_data
=
Vh
.
data
<
T
out
>
();
auto
*
rhs_data
=
rhs
.
data
<
T
out
>
();
math
::
SolveLinearSystem
<
T
out
>
(
matrix_data
,
rhs_data
,
x_grad_data
,
m
,
k
,
batch_count
);
auto
*
matrix_data
=
Vh
.
data
<
T
>
();
auto
*
rhs_data
=
rhs
.
data
<
T
>
();
math
::
SolveLinearSystem
<
T
>
(
matrix_data
,
rhs_data
,
x_grad_data
,
m
,
k
,
batch_count
);
}
template
<
typename
DeviceContext
,
typename
T
,
typename
Tout
>
...
...
paddle/phi/kernels/complex_kernel.h
浏览文件 @
e7afa391
...
...
@@ -24,6 +24,12 @@ namespace phi {
template
<
typename
T
,
typename
Context
>
void
ConjKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
);
template
<
typename
T
,
typename
Context
>
void
RealKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
);
template
<
typename
T
,
typename
Context
>
void
ImagKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
);
// If T is complex
template
<
typename
T
,
...
...
@@ -50,10 +56,56 @@ DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
return
x
;
}
template
<
typename
T
,
typename
Context
>
void
RealKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
);
// Real-part helper, overload selected when T is phi::dtype::complex<float> or
// phi::dtype::complex<double>.
//
// Creates an output tensor with phi::Empty, lets RealAndImagInferMeta fill in
// its meta information from `x`, runs RealKernel<T> into it and returns it by
// value.
template <typename T,
          typename Context,
          std::enable_if_t<
              std::is_same<T, phi::dtype::complex<float>>::value ||
                  std::is_same<T, phi::dtype::complex<double>>::value,
              bool> = true>
DenseTensor Real(const Context& dev_ctx, const DenseTensor& x) {
  auto real_out = phi::Empty<T, Context>(dev_ctx);

  // The meta wrapper points at real_out, so InferMeta updates it in place.
  MetaTensor real_out_meta(&real_out);
  RealAndImagInferMeta(x, &real_out_meta);

  RealKernel<T>(dev_ctx, x, &real_out);
  return real_out;
}
template
<
typename
T
,
typename
Context
>
void
ImagKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
);
// Real-part helper, overload selected when T is neither
// phi::dtype::complex<float> nor phi::dtype::complex<double>.
//
// No kernel is launched: the input tensor is returned as-is (by value), and
// `dev_ctx` is not used.
template <typename T,
          typename Context,
          std::enable_if_t<
              !std::is_same<T, phi::dtype::complex<float>>::value &&
                  !std::is_same<T, phi::dtype::complex<double>>::value,
              bool> = true>
DenseTensor Real(const Context& dev_ctx, const DenseTensor& x) {
  return x;
}
// Imaginary-part helper, overload selected when T is
// phi::dtype::complex<float> or phi::dtype::complex<double>.
//
// Creates an output tensor with phi::Empty, lets RealAndImagInferMeta fill in
// its meta information from `x`, runs ImagKernel<T> into it and returns it by
// value.
template <typename T,
          typename Context,
          std::enable_if_t<
              std::is_same<T, phi::dtype::complex<float>>::value ||
                  std::is_same<T, phi::dtype::complex<double>>::value,
              bool> = true>
DenseTensor Imag(const Context& dev_ctx, const DenseTensor& x) {
  auto imag_out = phi::Empty<T, Context>(dev_ctx);

  // The meta wrapper points at imag_out, so InferMeta updates it in place.
  MetaTensor imag_out_meta(&imag_out);
  RealAndImagInferMeta(x, &imag_out_meta);

  ImagKernel<T>(dev_ctx, x, &imag_out);
  return imag_out;
}
// Imaginary-part helper, overload selected when T is neither
// phi::dtype::complex<float> nor phi::dtype::complex<double>.
//
// No kernel is launched: the input tensor is returned as-is (by value), and
// `dev_ctx` is not used.
// NOTE(review): this hands back `x` itself, not a zero-filled tensor, as the
// "imaginary part" of a non-complex input - confirm that callers expect this.
template <typename T,
          typename Context,
          std::enable_if_t<
              !std::is_same<T, phi::dtype::complex<float>>::value &&
                  !std::is_same<T, phi::dtype::complex<double>>::value,
              bool> = true>
DenseTensor Imag(const Context& dev_ctx, const DenseTensor& x) {
  return x;
}
}
// namespace phi
paddle/phi/kernels/funcs/diag_functor.h
浏览文件 @
e7afa391
...
...
@@ -14,6 +14,14 @@
#pragma once
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/for_range.h"
// TODO(paddle-dev): Remove this file when we can call related Kernel directly
namespace
phi
{
namespace
funcs
{
...
...
@@ -25,5 +33,96 @@ inline int ComputeStride(int axis, phi::DDim dims) {
return
size
;
}
// Element-wise functor that copies `input` to `output`, except that one
// position per row is overwritten with a value taken from `scale`.
//
// For a flat element index the functor derives
//   col = index % n,  row = (index / n) % m
// and a half-open column band [band_start, band_end) for that row:
//   band_start = 0 if num_lower_diags < 0, else row - num_lower_diags
//   band_end   = n if num_upper_diags < 0, else row + num_upper_diags + 1
// The element is replaced by scale[index % m] only when `col` lies inside the
// band and is its last column (col == band_end - 1); every other element is
// copied through from `input` unchanged.
//
// NOTE(review): `index % m` is always in [0, m), so only the first m entries
// of `scale` are ever read, regardless of how many m x n matrices the index
// range spans - confirm this is intended for batched inputs.
template <typename T, typename ValueType>
struct DiagAndFillFunctor {
  DiagAndFillFunctor(const int m,
                     const int n,
                     const int num_lower_diags,
                     const int num_upper_diags,
                     const ValueType* scale,
                     const T* input,
                     T* output)
      : m_(m),
        n_(n),
        num_lower_diags_(num_lower_diags),
        num_upper_diags_(num_upper_diags),
        scale_(scale),
        input_(input),
        output_(output) {}

  HOSTDEVICE void operator()(size_t index) const {
    const int col = index % n_;
    const int row = (index / n_) % m_;

    const int band_start = (num_lower_diags_ < 0 ? 0 : row - num_lower_diags_);
    const int band_end =
        (num_upper_diags_ < 0 ? n_ : row + num_upper_diags_ + 1);

    const bool inside_band = (col >= band_start) && (col < band_end);
    if (inside_band && col == band_end - 1) {
      // Last column of the band: take the replacement value from scale_.
      output_[index] = static_cast<T>(scale_[index % m_]);
    } else {
      // Outside the band, or inside but not on its last column: plain copy.
      output_[index] = input_[index];
    }
  }

 private:
  // Matrix extents used to decompose the flat index.
  const int m_;
  const int n_;
  // Band half-widths; a negative value leaves that side of the band open.
  const int num_lower_diags_;
  const int num_upper_diags_;
  const ValueType* scale_;
  const T* input_;
  T* output_;
};
// Runs DiagAndFillFunctor over every element of `input` and returns the
// result in a freshly allocated tensor with the same dims as `input`.
//
// m, n, num_lower_diags and num_upper_diags are forwarded unchanged to the
// functor. `scale` is read as ValueType, `input` and the output as T.
template <typename T, typename ValueType, typename Context>
DenseTensor DiagFill(const Context& dev_ctx,
                     const int m,
                     const int n,
                     const int num_lower_diags,
                     const int num_upper_diags,
                     const DenseTensor& scale,
                     const DenseTensor& input) {
  DenseTensor out;
  out.Resize(input.dims());
  dev_ctx.template Alloc<T>(&out);

  // One functor invocation per element of `input`.
  funcs::ForRange<Context> for_range(dev_ctx, input.numel());

  DiagAndFillFunctor<T, ValueType> fill_functor(m,
                                                n,
                                                num_lower_diags,
                                                num_upper_diags,
                                                scale.data<ValueType>(),
                                                input.data<T>(),
                                                out.data<T>());
  for_range(fill_functor);
  return out;
}
// Extracts the main diagonal of each of `batch` square matrices stored
// back-to-back in `x` and returns them in a new tensor.
//
// The matrix order is taken from the last dimension of `x`; the output dims
// are the dims of `x` with the last dimension dropped. Elements are read and
// written as phi::dtype::Real<T>.
//
// NOTE(review): the output is obtained through HostAlloc and `x` is read
// through a raw pointer in a plain loop, so this presumably assumes `x` is
// host-accessible - confirm for non-CPU contexts.
template <typename T, typename Context>
DenseTensor BatchDiag(const Context& dev_ctx, const DenseTensor& x, int batch) {
  using RealT = phi::dtype::Real<T>;

  DenseTensor out;
  auto* src = x.data<RealT>();
  auto total = x.numel();

  // The buffer is sized for the full input element count; the dims are
  // narrowed to the diagonal shape further below.
  out.Resize(x.dims());
  auto* dst = dev_ctx.template HostAlloc<RealT>(
      &out, static_cast<size_t>(total * sizeof(RealT)));

  auto in_dims = x.dims();
  int ndim = in_dims.size();

  // Output shape: every input dimension except the last one.
  std::vector<int> diag_shape;
  for (int d = 0; d + 1 < ndim; ++d) {
    diag_shape.push_back(x.dims()[d]);
  }
  out.Resize(phi::make_ddim(diag_shape));

  int order = x.dims()[ndim - 1];
  int matrix_stride = order * order;  // elements per order x order matrix
  int diag_step = order + 1;          // distance between diagonal entries

  for (int b = 0; b < batch; ++b) {
    for (int k = 0; k < order; ++k) {
      dst[b * order + k] = src[matrix_stride * b + diag_step * k];
    }
  }
  return out;
}
}
// namespace funcs
}
// namespace phi
paddle/phi/kernels/funcs/slice.h
0 → 100644
浏览文件 @
e7afa391
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
// TODO(paddle-dev): Remove this file when we can call related Kernel directly
namespace
phi
{
namespace
funcs
{
template
<
typename
Context
,
typename
T
,
size_t
D
>
void
EigenSliceWrapper
(
const
Context
&
dev_ctx
,
const
DenseTensor
*
in
,
const
std
::
vector
<
int
>&
start
,
const
std
::
vector
<
int
>&
end
,
DenseTensor
*
out
)
{
// Slice by call Eigen Tensor Function `.slice()`
size_t
rank
=
in
->
dims
().
size
();
PADDLE_ENFORCE_EQ
(
start
.
size
(),
rank
,
errors
::
InvalidArgument
(
"EigenSliceWrapper function start "
"argument must have the same length as input rank."
));
PADDLE_ENFORCE_EQ
(
end
.
size
(),
rank
,
errors
::
InvalidArgument
(
"EigenSliceWrapper function end "
"argument must have the same length as input rank."
));
auto
eigen_place_ptr
=
dev_ctx
.
eigen_device
();
auto
eigen_place
=
*
eigen_place_ptr
;
auto
out_t
=
phi
::
EigenTensor
<
T
,
D
>::
From
(
*
out
,
out
->
dims
());
auto
in_t
=
phi
::
EigenTensor
<
T
,
D
>::
From
(
*
in
,
in
->
dims
());
Eigen
::
DSizes
<
int
,
D
>
offsets_32bit
,
extents_32bit
;
for
(
size_t
i
=
0
;
i
<
D
;
i
++
)
{
offsets_32bit
[
i
]
=
start
[
i
];
extents_32bit
[
i
]
=
end
[
i
];
}
EigenSlice
<
std
::
decay_t
<
decltype
(
eigen_place
)
>
,
T
,
D
>::
Eval
(
eigen_place
,
phi
::
To32BitIndex
(
out_t
),
phi
::
To32BitIndex
(
in_t
),
offsets_32bit
,
extents_32bit
);
}
#define SLICE_RANK_CASE(N) \
case N: { \
EigenSliceWrapper<Context, T, N>(dev_ctx, &x, offset, extends, &ret); \
break; \
}
// Returns a newly allocated tensor containing the region of `x` selected by
// `axes`, `starts` and `ends`.
//
// For every i, dimension axes[i] of `x` is restricted to the half-open range
// [starts[i], ends[i]); all other dimensions are kept in full.
//
//   axes   - dimensions to slice; a negative value counts from the last
//            dimension (rank is added to it). After that adjustment each axis
//            must lie in [0, rank).
//   starts - first index kept on the corresponding axis (used as given).
//   ends   - one past the last index kept; must be strictly greater than the
//            matching start.
//
// Raises InvalidArgument when the three vectors differ in length, when an
// axis is out of range, when ends[i] <= starts[i], or when the rank of `x` is
// not between 1 and 6.
template <typename T, typename Context>
DenseTensor Slice(const Context& dev_ctx,
                  const DenseTensor& x,
                  std::vector<int> axes,
                  std::vector<int> starts,
                  std::vector<int> ends) {
  DenseTensor ret;
  std::vector<int> new_axes = axes;
  std::vector<int> out_shape = phi::vectorize<int>(x.dims());
  size_t rank = out_shape.size();
  const int signed_rank = static_cast<int>(rank);

  PADDLE_ENFORCE_EQ(
      axes.size(),
      starts.size(),
      errors::InvalidArgument("Slice Operator Argument Invalided"));
  PADDLE_ENFORCE_EQ(
      ends.size(),
      starts.size(),
      errors::InvalidArgument("Slice Operator Argument Invalided"));

  for (size_t i = 0; i < axes.size(); ++i) {
    // Map a negative axis onto its non-negative equivalent.
    int axis = axes[i];
    if (axis < 0) axis += signed_rank;

    // Reject axes that would index outside out_shape / offset / extends.
    PADDLE_ENFORCE_GE(
        axis,
        0,
        errors::InvalidArgument(
            "Slice axis %d is out of range for a tensor of rank %d.",
            axes[i],
            signed_rank));
    PADDLE_ENFORCE_LT(
        axis,
        signed_rank,
        errors::InvalidArgument(
            "Slice axis %d is out of range for a tensor of rank %d.",
            axes[i],
            signed_rank));
    new_axes[i] = axis;

    int st = starts[i];
    int ed = ends[i];
    PADDLE_ENFORCE_GT(
        ed,
        st,
        errors::InvalidArgument(
            "C++ Slice Operation Not Support End <= Start"));
    out_shape[axis] = ed - st;
  }

  // Default to the whole tensor, then narrow the sliced axes.
  // These two names are referenced by SLICE_RANK_CASE.
  std::vector<int> offset(rank), extends(rank);
  for (size_t i = 0; i < rank; ++i) {
    offset[i] = 0;
    extends[i] = x.dims()[i];
  }
  for (size_t i = 0; i < new_axes.size(); ++i) {
    offset[new_axes[i]] = starts[i];
    extends[new_axes[i]] = ends[i] - starts[i];
  }

  ret.Resize(phi::make_ddim(out_shape));
  dev_ctx.template Alloc<T>(&ret);

  // EigenSliceWrapper needs the rank as a compile-time constant.
  switch (rank) {
    SLICE_RANK_CASE(1);
    SLICE_RANK_CASE(2);
    SLICE_RANK_CASE(3);
    SLICE_RANK_CASE(4);
    SLICE_RANK_CASE(5);
    SLICE_RANK_CASE(6);
    default: {
      PADDLE_THROW(errors::InvalidArgument(
          "Invalid Rank number, "
          "currently only support rank between 1~6"));
    }
  }
  return ret;
}
}
// namespace funcs
}
// namespace phi
paddle/phi/kernels/funcs/unsqueeze.h
0 → 100644
浏览文件 @
e7afa391
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
// TODO(paddle-dev): Remove this file when we can call related Kernel directly
namespace
phi
{
namespace
funcs
{
// Returns a copy-constructed DenseTensor of `x` whose dims have an extra
// dimension of size 1 inserted; only the dims are changed here.
//
// For axis >= 0 the new dimension is inserted before position `axis`. For
// axis < 0 the position is counted from the end, so -1 appends the new
// dimension after the last one. `axis` is not range-checked.
inline const DenseTensor Unsqueeze(const DenseTensor& x, int axis = 0) {
  // don't copy data, only change the dims
  DenseTensor out(x);

  std::vector<int> new_shape = phi::vectorize<int>(x.dims());
  auto insert_pos =
      (axis >= 0) ? (new_shape.begin() + axis) : (new_shape.end() + axis + 1);
  new_shape.insert(insert_pos, 1);

  out.Resize(phi::make_ddim(new_shape));
  return out;
}
}
// namespace funcs
}
// namespace phi
paddle/phi/kernels/matmul_kernel.h
浏览文件 @
e7afa391
...
...
@@ -33,8 +33,8 @@ template <typename T, typename Context>
DenseTensor
Matmul
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
bool
transpose_x
,
bool
transpose_y
)
{
bool
transpose_x
=
false
,
bool
transpose_y
=
false
)
{
auto
dense_out
=
Empty
<
T
,
Context
>
(
dev_ctx
);
MetaTensor
meta_out
(
&
dense_out
);
MatmulInferMeta
(
x
,
y
,
transpose_x
,
transpose_y
,
&
meta_out
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录