Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
783c4aba
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
783c4aba
编写于
2月 25, 2022
作者:
L
Linjie Chen
提交者:
GitHub
2月 25, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move diag_v2 to phi (#39914)
上级
2533cac6
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
340 addition
and
251 deletion
+340
-251
paddle/fluid/operators/diag_v2_op.cc
paddle/fluid/operators/diag_v2_op.cc
+8
-88
paddle/fluid/operators/diag_v2_op.cu
paddle/fluid/operators/diag_v2_op.cu
+0
-128
paddle/fluid/operators/diag_v2_op.h
paddle/fluid/operators/diag_v2_op.h
+0
-34
paddle/phi/core/compat/op_utils.h
paddle/phi/core/compat/op_utils.h
+2
-1
paddle/phi/infermeta/binary.cc
paddle/phi/infermeta/binary.cc
+1
-0
paddle/phi/infermeta/unary.cc
paddle/phi/infermeta/unary.cc
+40
-0
paddle/phi/infermeta/unary.h
paddle/phi/infermeta/unary.h
+5
-0
paddle/phi/kernels/cpu/diag_kernel.cc
paddle/phi/kernels/cpu/diag_kernel.cc
+66
-0
paddle/phi/kernels/diag_kernel.h
paddle/phi/kernels/diag_kernel.h
+28
-0
paddle/phi/kernels/funcs/diag_functor.h
paddle/phi/kernels/funcs/diag_functor.h
+29
-0
paddle/phi/kernels/gpu/diag_kernel.cu
paddle/phi/kernels/gpu/diag_kernel.cu
+134
-0
paddle/phi/ops/compat/diag_sig.cc
paddle/phi/ops/compat/diag_sig.cc
+27
-0
未找到文件。
paddle/fluid/operators/diag_v2_op.cc
浏览文件 @
783c4aba
...
@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/diag_v2_op.h"
#include <algorithm>
#include <algorithm>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -23,44 +25,6 @@ namespace operators {
...
@@ -23,44 +25,6 @@ namespace operators {
class
DiagV2Op
:
public
framework
::
OperatorWithKernel
{
class
DiagV2Op
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"diag_v2"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"diag_v2"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
offset
=
ctx
->
Attrs
().
Get
<
int
>
(
"offset"
);
if
(
x_dims
.
size
()
==
1UL
)
{
int64_t
size_
=
x_dims
[
0
]
+
std
::
abs
(
offset
);
ctx
->
SetOutputDim
(
"Out"
,
{
size_
,
size_
});
}
else
if
(
x_dims
.
size
()
==
2UL
)
{
int64_t
size_
=
0
;
if
(
offset
>=
0
)
{
// Note(LutaoChu): Do not use std::min here, otherwise the calculation
// of `size_` will have unexpected result on Windows Python3.8
if
(
x_dims
[
0
]
<
x_dims
[
1
]
-
offset
)
{
size_
=
x_dims
[
0
];
}
else
{
size_
=
x_dims
[
1
]
-
offset
;
}
}
else
{
// Note(LutaoChu): Do not use std::min here, otherwise the calculation
// of `size_` will have unexpected result on Windows Python3.8
if
(
x_dims
[
0
]
+
offset
<
x_dims
[
1
])
{
size_
=
x_dims
[
0
]
+
offset
;
}
else
{
size_
=
x_dims
[
1
];
}
}
ctx
->
SetOutputDim
(
"Out"
,
{
size_
});
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"The input tensor X's dimensions of DiagV2Op should be either 1 or "
"2, but received %d."
,
x_dims
.
size
()));
}
}
};
};
class
DiagV2OpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
DiagV2OpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
@@ -94,59 +58,15 @@ class DiagV2OpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -94,59 +58,15 @@ class DiagV2OpMaker : public framework::OpProtoAndCheckerMaker {
}
}
};
};
template
<
typename
DeviceContext
,
typename
T
>
class
DiagV2Kernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
X
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
x_data
=
X
->
data
<
T
>
();
auto
x_dims
=
X
->
dims
();
int
offset
=
context
.
Attr
<
int
>
(
"offset"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
out_dims
=
out
->
dims
();
int64_t
i
;
if
(
x_dims
.
size
()
==
1
)
{
float
padding_value
=
context
.
Attr
<
float
>
(
"padding_value"
);
phi
::
funcs
::
SetConstant
<
DeviceContext
,
T
>
set_padding_value
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
set_padding_value
(
dev_ctx
,
out
,
static_cast
<
T
>
(
padding_value
));
auto
x_length
=
x_dims
[
0
];
const
int
&
x_stride
=
ComputeStride
(
0
,
x_dims
);
auto
out_stride_0
=
ComputeStride
(
0
,
out_dims
);
auto
out_stride_1
=
ComputeStride
(
1
,
out_dims
);
out_data
+=
(
offset
>=
0
?
offset
*
out_stride_1
:
-
offset
*
out_stride_0
);
for
(
i
=
0
;
i
<
x_length
;
i
++
)
{
out_data
[
i
*
(
out_stride_0
+
out_stride_1
)]
=
x_data
[
i
*
x_stride
];
}
}
else
{
auto
out_length
=
out_dims
[
0
];
const
int
&
x_stride_0
=
ComputeStride
(
0
,
x_dims
);
const
int
&
x_stride_1
=
ComputeStride
(
1
,
x_dims
);
auto
out_stride_0
=
ComputeStride
(
0
,
out_dims
);
x_data
+=
(
offset
>=
0
?
offset
*
x_stride_1
:
-
offset
*
x_stride_0
);
for
(
i
=
0
;
i
<
out_length
;
i
++
)
{
out_data
[
i
*
out_stride_0
]
=
x_data
[
i
*
(
x_stride_0
+
x_stride_1
)];
}
}
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
DELCARE_INFER_SHAPE_FUNCTOR
(
diag_v2
,
DiagInferShapeFunctor
,
PT_INFER_META
(
phi
::
DiagInferMeta
));
REGISTER_OPERATOR
(
REGISTER_OPERATOR
(
diag_v2
,
ops
::
DiagV2Op
,
ops
::
DiagV2OpMaker
,
diag_v2
,
ops
::
DiagV2Op
,
ops
::
DiagV2OpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
REGISTER_OP_CPU_KERNEL
(
DiagInferShapeFunctor
);
diag_v2
,
ops
::
DiagV2Kernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
DiagV2Kernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
DiagV2Kernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
DiagV2Kernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
paddle/fluid/operators/diag_v2_op.cu
已删除
100644 → 0
浏览文件 @
2533cac6
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <tuple>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/diag_v2_op.h"
namespace
paddle
{
namespace
operators
{
// Extract the diagonal of a matrix 'x' to a vector 'out'.
template
<
typename
T
>
__global__
void
ExtractDiagonalKernel
(
T
*
out
,
const
T
*
x
,
std
::
ptrdiff_t
start
,
std
::
ptrdiff_t
size
,
const
std
::
ptrdiff_t
sumStride
,
const
std
::
ptrdiff_t
outStride
)
{
for
(
std
::
ptrdiff_t
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
const
std
::
ptrdiff_t
xOffset
=
start
+
sumStride
*
idx
;
out
[
outStride
*
idx
]
=
x
[
xOffset
];
}
}
// Paste a vector 'x' to the diagonal of a matrix 'out'
template
<
typename
T
>
__global__
void
PasteDiagonalKernel
(
T
*
out
,
const
T
*
x
,
std
::
ptrdiff_t
start
,
std
::
ptrdiff_t
x_length
,
const
std
::
ptrdiff_t
sumStride
,
const
std
::
ptrdiff_t
xStride
)
{
for
(
std
::
ptrdiff_t
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
x_length
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
const
std
::
ptrdiff_t
outOffset
=
start
+
sumStride
*
idx
;
out
[
outOffset
]
=
x
[
xStride
*
idx
];
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
DiagV2CUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
X
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
x_data
=
X
->
data
<
T
>
();
auto
x_dims
=
X
->
dims
();
int
offset
=
context
.
Attr
<
int
>
(
"offset"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
out_dims
=
out
->
dims
();
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
GetBlockGridSize
=
[
&
dev_ctx
](
int64_t
size
)
{
const
int64_t
block_size
=
std
::
min
(
size
,
static_cast
<
int64_t
>
(
dev_ctx
.
GetMaxThreadsPerBlock
()));
int64_t
max_threads
=
dev_ctx
.
GetMaxPhysicalThreadCount
();
const
int64_t
max_blocks
=
std
::
max
(((
max_threads
-
1
)
/
block_size
+
1
),
static_cast
<
int64_t
>
(
1
));
const
int64_t
grid_size
=
std
::
min
(
max_blocks
,
(
size
+
block_size
-
1
)
/
block_size
);
return
std
::
tuple
<
int64_t
,
int64_t
>
{
block_size
,
grid_size
};
};
if
(
x_dims
.
size
()
==
1
)
{
float
padding_value
=
context
.
Attr
<
float
>
(
"padding_value"
);
phi
::
funcs
::
SetConstant
<
DeviceContext
,
T
>
set_padding_value
;
set_padding_value
(
dev_ctx
,
out
,
static_cast
<
T
>
(
padding_value
));
auto
x_length
=
x_dims
[
0
];
auto
size
=
(
offset
>
0
)
?
x_length
+
offset
:
x_length
-
offset
;
const
int
&
x_stride
=
ComputeStride
(
0
,
x_dims
);
if
(
size
>
0
)
{
const
auto
&
out_stride_0
=
ComputeStride
(
0
,
out_dims
);
const
auto
&
out_stride_1
=
ComputeStride
(
1
,
out_dims
);
auto
start
=
(
offset
>=
0
?
offset
*
out_stride_1
:
-
offset
*
out_stride_0
);
std
::
tuple
<
int64_t
,
int64_t
>
block_grid_size
=
GetBlockGridSize
(
size
);
PasteDiagonalKernel
<
T
><<<
std
::
get
<
1
>
(
block_grid_size
),
std
::
get
<
0
>
(
block_grid_size
),
0
,
dev_ctx
.
stream
()
>>>
(
out_data
,
x_data
,
start
,
x_length
,
out_stride_0
+
out_stride_1
,
x_stride
);
}
}
else
{
const
int
&
x_stride_0
=
ComputeStride
(
0
,
x_dims
);
const
int
&
x_stride_1
=
ComputeStride
(
1
,
x_dims
);
int64_t
size
;
if
(
offset
>
0
)
{
size
=
std
::
min
(
x_dims
[
0
],
x_dims
[
1
]
-
offset
);
}
else
{
size
=
std
::
min
(
x_dims
[
0
]
+
offset
,
x_dims
[
1
]);
}
if
(
size
>
0
)
{
auto
start
=
(
offset
>=
0
?
offset
*
x_stride_1
:
-
offset
*
x_stride_0
);
const
auto
&
out_stride_0
=
ComputeStride
(
0
,
out_dims
);
std
::
tuple
<
int64_t
,
int64_t
>
block_grid_size
=
GetBlockGridSize
(
size
);
ExtractDiagonalKernel
<
T
><<<
std
::
get
<
1
>
(
block_grid_size
),
std
::
get
<
0
>
(
block_grid_size
),
0
,
dev_ctx
.
stream
()
>>>
(
out_data
,
x_data
,
start
,
size
,
x_stride_0
+
x_stride_1
,
out_stride_0
);
}
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
diag_v2
,
ops
::
DiagV2CUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
DiagV2CUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
,
ops
::
DiagV2CUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
DiagV2CUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/diag_v2_op.h
已删除
100644 → 0
浏览文件 @
2533cac6
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
paddle
{
namespace
operators
{
using
DDim
=
framework
::
DDim
;
static
inline
int
ComputeStride
(
int
axis
,
DDim
dims
)
{
int
size
=
1
;
for
(
int
i
=
axis
+
1
;
i
<
dims
.
size
();
i
++
)
{
size
*=
dims
[
i
];
}
return
size
;
}
}
// namespace operators
}
// namespace paddle
paddle/phi/core/compat/op_utils.h
浏览文件 @
783c4aba
...
@@ -37,7 +37,8 @@ const std::unordered_set<std::string> standard_kernel_suffixs({
...
@@ -37,7 +37,8 @@ const std::unordered_set<std::string> standard_kernel_suffixs({
* after 2.0, and can no longer be occupied by the previously abandoned ops.
* after 2.0, and can no longer be occupied by the previously abandoned ops.
* They are marked here uniformly.
* They are marked here uniformly.
*/
*/
const
std
::
unordered_set
<
std
::
string
>
deprecated_op_names
({
"flatten"
,
const
std
::
unordered_set
<
std
::
string
>
deprecated_op_names
({
"diag"
,
"flatten"
,
"flatten_grad"
,
"flatten_grad"
,
"matmul"
,
"matmul"
,
"matmul_grad"
,
"matmul_grad"
,
...
...
paddle/phi/infermeta/binary.cc
浏览文件 @
783c4aba
...
@@ -310,6 +310,7 @@ void BCELossInferMeta(const MetaTensor& input,
...
@@ -310,6 +310,7 @@ void BCELossInferMeta(const MetaTensor& input,
}
}
out
->
set_dims
(
input_dims
);
out
->
set_dims
(
input_dims
);
out
->
set_dtype
(
input
.
dtype
());
out
->
share_lod
(
input
);
out
->
share_lod
(
input
);
}
}
...
...
paddle/phi/infermeta/unary.cc
浏览文件 @
783c4aba
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/unary.h"
#include <algorithm>
#include <set>
#include <set>
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/enforce.h"
...
@@ -715,6 +716,45 @@ void UnfoldInferMeta(const MetaTensor& x,
...
@@ -715,6 +716,45 @@ void UnfoldInferMeta(const MetaTensor& x,
out
->
set_dims
(
phi
::
make_ddim
(
out_dims
));
out
->
set_dims
(
phi
::
make_ddim
(
out_dims
));
}
}
void
DiagInferMeta
(
const
MetaTensor
&
x
,
int
offset
,
float
padding_value
,
MetaTensor
*
out
)
{
auto
x_dims
=
x
.
dims
();
if
(
x_dims
.
size
()
==
1UL
)
{
int64_t
size_
=
x_dims
[
0
]
+
std
::
abs
(
offset
);
out
->
set_dims
({
size_
,
size_
});
out
->
set_dtype
(
x
.
dtype
());
}
else
if
(
x_dims
.
size
()
==
2UL
)
{
int64_t
size_
=
0
;
if
(
offset
>=
0
)
{
// Note(LutaoChu): Do not use std::min here, otherwise the calculation
// of `size_` will have unexpected result on Windows Python3.8
if
(
x_dims
[
0
]
<
x_dims
[
1
]
-
offset
)
{
size_
=
x_dims
[
0
];
}
else
{
size_
=
x_dims
[
1
]
-
offset
;
}
}
else
{
// Note(LutaoChu): Do not use std::min here, otherwise the calculation
// of `size_` will have unexpected result on Windows Python3.8
if
(
x_dims
[
0
]
+
offset
<
x_dims
[
1
])
{
size_
=
x_dims
[
0
]
+
offset
;
}
else
{
size_
=
x_dims
[
1
];
}
}
out
->
set_dims
({
size_
});
out
->
set_dtype
(
x
.
dtype
());
}
else
{
PADDLE_THROW
(
phi
::
errors
::
InvalidArgument
(
"The input tensor X's dimensions of DiagV2Op should be either 1 or "
"2, but received %d."
,
x_dims
.
size
()));
}
}
}
// namespace phi
}
// namespace phi
PD_REGISTER_INFER_META_FN
(
copy_to
,
phi
::
CopyToInferMeta
);
PD_REGISTER_INFER_META_FN
(
copy_to
,
phi
::
CopyToInferMeta
);
...
...
paddle/phi/infermeta/unary.h
浏览文件 @
783c4aba
...
@@ -104,4 +104,9 @@ void UnfoldInferMeta(const MetaTensor& x,
...
@@ -104,4 +104,9 @@ void UnfoldInferMeta(const MetaTensor& x,
MetaTensor
*
out
,
MetaTensor
*
out
,
MetaConfig
config
=
MetaConfig
());
MetaConfig
config
=
MetaConfig
());
void
DiagInferMeta
(
const
MetaTensor
&
x
,
int
offset
,
float
padding_value
,
MetaTensor
*
out
);
}
// namespace phi
}
// namespace phi
paddle/phi/kernels/cpu/diag_kernel.cc
0 → 100644
浏览文件 @
783c4aba
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/diag_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
DiagKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int
offset
,
float
padding_value
,
DenseTensor
*
out
)
{
auto
*
x_data
=
x
.
data
<
T
>
();
auto
x_dims
=
x
.
dims
();
T
*
out_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
auto
out_dims
=
out
->
dims
();
int64_t
i
;
if
(
x_dims
.
size
()
==
1
)
{
phi
::
funcs
::
SetConstant
<
Context
,
T
>
set_padding_value
;
set_padding_value
(
dev_ctx
,
out
,
static_cast
<
T
>
(
padding_value
));
auto
x_length
=
x_dims
[
0
];
const
int
&
x_stride
=
phi
::
funcs
::
ComputeStride
(
0
,
x_dims
);
auto
out_stride_0
=
phi
::
funcs
::
ComputeStride
(
0
,
out_dims
);
auto
out_stride_1
=
phi
::
funcs
::
ComputeStride
(
1
,
out_dims
);
out_data
+=
(
offset
>=
0
?
offset
*
out_stride_1
:
-
offset
*
out_stride_0
);
for
(
i
=
0
;
i
<
x_length
;
i
++
)
{
out_data
[
i
*
(
out_stride_0
+
out_stride_1
)]
=
x_data
[
i
*
x_stride
];
}
}
else
{
auto
out_length
=
out_dims
[
0
];
const
int
&
x_stride_0
=
phi
::
funcs
::
ComputeStride
(
0
,
x_dims
);
const
int
&
x_stride_1
=
phi
::
funcs
::
ComputeStride
(
1
,
x_dims
);
auto
out_stride_0
=
phi
::
funcs
::
ComputeStride
(
0
,
out_dims
);
x_data
+=
(
offset
>=
0
?
offset
*
x_stride_1
:
-
offset
*
x_stride_0
);
for
(
i
=
0
;
i
<
out_length
;
i
++
)
{
out_data
[
i
*
out_stride_0
]
=
x_data
[
i
*
(
x_stride_0
+
x_stride_1
)];
}
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
diag
,
CPU
,
ALL_LAYOUT
,
phi
::
DiagKernel
,
int
,
float
,
double
,
int64_t
)
{}
paddle/phi/kernels/diag_kernel.h
0 → 100644
浏览文件 @
783c4aba
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
DiagKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int
offset
,
float
padding_value
,
DenseTensor
*
out
);
}
// namespace phi
paddle/phi/kernels/funcs/diag_functor.h
0 → 100644
浏览文件 @
783c4aba
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace
phi
{
namespace
funcs
{
inline
int
ComputeStride
(
int
axis
,
phi
::
DDim
dims
)
{
int
size
=
1
;
for
(
int
i
=
axis
+
1
;
i
<
dims
.
size
();
i
++
)
{
size
*=
dims
[
i
];
}
return
size
;
}
}
// namespace funcs
}
// namespace phi
paddle/phi/kernels/gpu/diag_kernel.cu
0 → 100644
浏览文件 @
783c4aba
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/diag_kernel.h"
#include <algorithm>
#include <tuple>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
phi
{
// Extract the diagonal of a matrix 'x' to a vector 'out'.
template
<
typename
T
>
__global__
void
ExtractDiagonalKernel
(
T
*
out
,
const
T
*
x
,
std
::
ptrdiff_t
start
,
std
::
ptrdiff_t
size
,
const
std
::
ptrdiff_t
sumStride
,
const
std
::
ptrdiff_t
outStride
)
{
for
(
std
::
ptrdiff_t
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
const
std
::
ptrdiff_t
xOffset
=
start
+
sumStride
*
idx
;
out
[
outStride
*
idx
]
=
x
[
xOffset
];
}
}
// Paste a vector 'x' to the diagonal of a matrix 'out'
template
<
typename
T
>
__global__
void
PasteDiagonalKernel
(
T
*
out
,
const
T
*
x
,
std
::
ptrdiff_t
start
,
std
::
ptrdiff_t
x_length
,
const
std
::
ptrdiff_t
sumStride
,
const
std
::
ptrdiff_t
xStride
)
{
for
(
std
::
ptrdiff_t
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
x_length
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
const
std
::
ptrdiff_t
outOffset
=
start
+
sumStride
*
idx
;
out
[
outOffset
]
=
x
[
xStride
*
idx
];
}
}
template
<
typename
T
,
typename
Context
>
void
DiagKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int
offset
,
float
padding_value
,
DenseTensor
*
out
)
{
auto
*
x_data
=
x
.
data
<
T
>
();
auto
x_dims
=
x
.
dims
();
T
*
out_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
auto
out_dims
=
out
->
dims
();
auto
GetBlockGridSize
=
[
&
dev_ctx
](
int64_t
size
)
{
const
int64_t
block_size
=
std
::
min
(
size
,
static_cast
<
int64_t
>
(
dev_ctx
.
GetMaxThreadsPerBlock
()));
int64_t
max_threads
=
dev_ctx
.
GetMaxPhysicalThreadCount
();
const
int64_t
max_blocks
=
std
::
max
(((
max_threads
-
1
)
/
block_size
+
1
),
static_cast
<
int64_t
>
(
1
));
const
int64_t
grid_size
=
std
::
min
(
max_blocks
,
(
size
+
block_size
-
1
)
/
block_size
);
return
std
::
tuple
<
int64_t
,
int64_t
>
{
block_size
,
grid_size
};
};
if
(
x_dims
.
size
()
==
1
)
{
phi
::
funcs
::
SetConstant
<
Context
,
T
>
set_padding_value
;
set_padding_value
(
dev_ctx
,
out
,
static_cast
<
T
>
(
padding_value
));
auto
x_length
=
x_dims
[
0
];
auto
size
=
(
offset
>
0
)
?
x_length
+
offset
:
x_length
-
offset
;
const
int
&
x_stride
=
phi
::
funcs
::
ComputeStride
(
0
,
x_dims
);
if
(
size
>
0
)
{
const
auto
&
out_stride_0
=
phi
::
funcs
::
ComputeStride
(
0
,
out_dims
);
const
auto
&
out_stride_1
=
phi
::
funcs
::
ComputeStride
(
1
,
out_dims
);
auto
start
=
(
offset
>=
0
?
offset
*
out_stride_1
:
-
offset
*
out_stride_0
);
std
::
tuple
<
int64_t
,
int64_t
>
block_grid_size
=
GetBlockGridSize
(
size
);
PasteDiagonalKernel
<
T
><<<
std
::
get
<
1
>
(
block_grid_size
),
std
::
get
<
0
>
(
block_grid_size
),
0
,
dev_ctx
.
stream
()
>>>
(
out_data
,
x_data
,
start
,
x_length
,
out_stride_0
+
out_stride_1
,
x_stride
);
}
}
else
{
const
int
&
x_stride_0
=
phi
::
funcs
::
ComputeStride
(
0
,
x_dims
);
const
int
&
x_stride_1
=
phi
::
funcs
::
ComputeStride
(
1
,
x_dims
);
int64_t
size
;
if
(
offset
>
0
)
{
size
=
std
::
min
(
x_dims
[
0
],
x_dims
[
1
]
-
offset
);
}
else
{
size
=
std
::
min
(
x_dims
[
0
]
+
offset
,
x_dims
[
1
]);
}
if
(
size
>
0
)
{
auto
start
=
(
offset
>=
0
?
offset
*
x_stride_1
:
-
offset
*
x_stride_0
);
const
auto
&
out_stride_0
=
phi
::
funcs
::
ComputeStride
(
0
,
out_dims
);
std
::
tuple
<
int64_t
,
int64_t
>
block_grid_size
=
GetBlockGridSize
(
size
);
ExtractDiagonalKernel
<
T
><<<
std
::
get
<
1
>
(
block_grid_size
),
std
::
get
<
0
>
(
block_grid_size
),
0
,
dev_ctx
.
stream
()
>>>
(
out_data
,
x_data
,
start
,
size
,
x_stride_0
+
x_stride_1
,
out_stride_0
);
}
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
diag
,
GPU
,
ALL_LAYOUT
,
phi
::
DiagKernel
,
int
,
int64_t
,
float
,
double
)
{}
paddle/phi/ops/compat/diag_sig.cc
0 → 100644
浏览文件 @
783c4aba
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
DiagOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"diag"
,
{
"X"
},
{
"offset"
,
"padding_value"
},
{
"Out"
});
}
}
// namespace phi
PD_REGISTER_BASE_KERNEL_NAME
(
diag_v2
,
diag
);
PD_REGISTER_ARG_MAPPING_FN
(
diag_v2
,
phi
::
DiagOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录