Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
41f11d29
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
41f11d29
编写于
7月 20, 2022
作者:
Z
Zhong Hui
提交者:
GitHub
7月 20, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PHI] move diag_embed op to phi. (#44408)
* move diag_embed to phi.
上级
889bdde3
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
310 addition
and
89 deletion
+310
-89
paddle/fluid/operators/diag_embed_op.cc
paddle/fluid/operators/diag_embed_op.cc
+10
-83
paddle/phi/api/yaml/legacy_api.yaml
paddle/phi/api/yaml/legacy_api.yaml
+8
-0
paddle/phi/infermeta/unary.cc
paddle/phi/infermeta/unary.cc
+63
-0
paddle/phi/infermeta/unary.h
paddle/phi/infermeta/unary.h
+3
-0
paddle/phi/kernels/cpu/diag_embed_kernel.cc
paddle/phi/kernels/cpu/diag_embed_kernel.cc
+28
-0
paddle/phi/kernels/diag_embed_kernel.h
paddle/phi/kernels/diag_embed_kernel.h
+29
-0
paddle/phi/kernels/gpu/diag_embed_kernel.cu
paddle/phi/kernels/gpu/diag_embed_kernel.cu
+28
-0
paddle/phi/kernels/impl/diag_embed_impl.h
paddle/phi/kernels/impl/diag_embed_impl.h
+129
-0
python/paddle/fluid/tests/unittests/test_diag_embed.py
python/paddle/fluid/tests/unittests/test_diag_embed.py
+2
-1
python/paddle/nn/functional/extension.py
python/paddle/nn/functional/extension.py
+10
-5
未找到文件。
paddle/fluid/operators/diag_embed_op.cc
浏览文件 @
41f11d29
...
...
@@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/diag_embed_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -20,81 +23,6 @@ namespace operators {
class
DiagEmbedOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Input"
),
true
,
platform
::
errors
::
NotFound
(
"Input of DiagEmbedOp is not found."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
platform
::
errors
::
NotFound
(
"Output of DiagEmbedOp is not found."
));
int
offset
=
ctx
->
Attrs
().
Get
<
int
>
(
"offset"
);
int
dim1
=
ctx
->
Attrs
().
Get
<
int
>
(
"dim1"
);
int
dim2
=
ctx
->
Attrs
().
Get
<
int
>
(
"dim2"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"Input"
);
PADDLE_ENFORCE_GE
(
dim1
,
-
(
x_dims
.
size
()
+
1
),
platform
::
errors
::
OutOfRange
(
"Dim1 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld)."
,
-
(
x_dims
.
size
()
+
1
),
x_dims
.
size
(),
dim1
));
PADDLE_ENFORCE_LE
(
dim1
,
x_dims
.
size
(),
platform
::
errors
::
OutOfRange
(
"Dim1 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld)."
,
-
(
x_dims
.
size
()
+
1
),
x_dims
.
size
(),
dim1
));
PADDLE_ENFORCE_GE
(
dim2
,
-
(
x_dims
.
size
()
+
1
),
platform
::
errors
::
OutOfRange
(
"Dim2 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld)."
,
-
(
x_dims
.
size
()
+
1
),
x_dims
.
size
(),
dim2
));
PADDLE_ENFORCE_LE
(
dim2
,
x_dims
.
size
(),
platform
::
errors
::
OutOfRange
(
"Dim2 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld)."
,
-
(
x_dims
.
size
()
+
1
),
x_dims
.
size
(),
dim2
));
int
dim1_
=
dim1
<
0
?
x_dims
.
size
()
+
dim1
+
1
:
dim1
;
int
dim2_
=
dim2
<
0
?
x_dims
.
size
()
+
dim2
+
1
:
dim2
;
int
offset_
=
std
::
abs
(
offset
);
PADDLE_ENFORCE_NE
(
dim1_
,
dim2_
,
platform
::
errors
::
InvalidArgument
(
"diagonal dimensions should not be identical "
"%ld vs %ld."
,
dim1
,
dim2
));
int
new_dim_len
=
offset_
+
x_dims
[
x_dims
.
size
()
-
1
];
auto
sizes
=
vectorize
(
x_dims
);
sizes
.
pop_back
();
sizes
.
insert
(
sizes
.
begin
()
+
std
::
min
(
dim1_
,
dim2_
),
new_dim_len
);
sizes
.
insert
(
sizes
.
begin
()
+
std
::
max
(
dim1_
,
dim2_
),
new_dim_len
);
ctx
->
SetOutputDim
(
"Out"
,
phi
::
make_ddim
(
sizes
));
}
};
class
DiagEmbedOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
@@ -131,15 +59,14 @@ class DiagEmbedOpMaker : public framework::OpProtoAndCheckerMaker {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
platform
=
paddle
::
platform
;
DECLARE_INFER_SHAPE_FUNCTOR
(
diag_embed
,
DiagEmbedInferShapeFunctor
,
PD_INFER_META
(
phi
::
DiagEmbedInferMeta
));
REGISTER_OPERATOR
(
diag_embed
,
ops
::
DiagEmbedOp
,
ops
::
DiagEmbedOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OP_CPU_KERNEL
(
diag_embed
,
ops
::
DiagEmbedKernel
<
phi
::
CPUContext
,
int
>
,
ops
::
DiagEmbedKernel
<
phi
::
CPUContext
,
float
>
,
ops
::
DiagEmbedKernel
<
phi
::
CPUContext
,
double
>
,
ops
::
DiagEmbedKernel
<
phi
::
CPUContext
,
int64_t
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
DiagEmbedInferShapeFunctor
);
paddle/phi/api/yaml/legacy_api.yaml
浏览文件 @
41f11d29
...
...
@@ -524,6 +524,14 @@
func
:
determinant
backward
:
det_grad
-
api
:
diag_embed
args
:
(Tensor x, int offset, int dim1, int dim2)
output
:
Tensor
infer_meta
:
func
:
DiagEmbedInferMeta
kernel
:
func
:
diag_embed
-
api
:
divide
args
:
(Tensor x, Tensor y)
output
:
Tensor
...
...
paddle/phi/infermeta/unary.cc
浏览文件 @
41f11d29
...
...
@@ -288,6 +288,69 @@ void CumInferMeta(const MetaTensor& x,
out
->
share_lod
(
x
);
}
void
DiagEmbedInferMeta
(
const
MetaTensor
&
x
,
int
offset
,
int
dim1
,
int
dim2
,
MetaTensor
*
out
)
{
auto
x_dims
=
x
.
dims
();
PADDLE_ENFORCE_GE
(
dim1
,
-
(
x_dims
.
size
()
+
1
),
phi
::
errors
::
OutOfRange
(
"Dim1 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld)."
,
-
(
x_dims
.
size
()
+
1
),
x_dims
.
size
(),
dim1
));
PADDLE_ENFORCE_LE
(
dim1
,
x_dims
.
size
(),
phi
::
errors
::
OutOfRange
(
"Dim1 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld)."
,
-
(
x_dims
.
size
()
+
1
),
x_dims
.
size
(),
dim1
));
PADDLE_ENFORCE_GE
(
dim2
,
-
(
x_dims
.
size
()
+
1
),
phi
::
errors
::
OutOfRange
(
"Dim2 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld)."
,
-
(
x_dims
.
size
()
+
1
),
x_dims
.
size
(),
dim2
));
PADDLE_ENFORCE_LE
(
dim2
,
x_dims
.
size
(),
phi
::
errors
::
OutOfRange
(
"Dim2 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld)."
,
-
(
x_dims
.
size
()
+
1
),
x_dims
.
size
(),
dim2
));
int
dim1_
=
dim1
<
0
?
x_dims
.
size
()
+
dim1
+
1
:
dim1
;
int
dim2_
=
dim2
<
0
?
x_dims
.
size
()
+
dim2
+
1
:
dim2
;
int
offset_
=
std
::
abs
(
offset
);
PADDLE_ENFORCE_NE
(
dim1_
,
dim2_
,
phi
::
errors
::
InvalidArgument
(
"diagonal dimensions should not be identical "
"%ld vs %ld."
,
dim1
,
dim2
));
int
new_dim_len
=
offset_
+
x_dims
[
x_dims
.
size
()
-
1
];
auto
sizes
=
vectorize
(
x_dims
);
sizes
.
pop_back
();
sizes
.
insert
(
sizes
.
begin
()
+
std
::
min
(
dim1_
,
dim2_
),
new_dim_len
);
sizes
.
insert
(
sizes
.
begin
()
+
std
::
max
(
dim1_
,
dim2_
),
new_dim_len
);
out
->
set_dims
(
phi
::
make_ddim
(
sizes
));
out
->
set_dtype
(
x
.
dtype
());
}
void
DiagInferMeta
(
const
MetaTensor
&
x
,
int
offset
,
float
padding_value
,
...
...
paddle/phi/infermeta/unary.h
浏览文件 @
41f11d29
...
...
@@ -71,6 +71,9 @@ void CumInferMeta(const MetaTensor& x,
bool
reverse
,
MetaTensor
*
out
);
void
DiagEmbedInferMeta
(
const
MetaTensor
&
x
,
int
offset
,
int
dim1
,
int
dim2
,
MetaTensor
*
out
);
void
DiagInferMeta
(
const
MetaTensor
&
x
,
int
offset
,
float
padding_value
,
...
...
paddle/
fluid/operators/diag_embed_op.cu
→
paddle/
phi/kernels/cpu/diag_embed_kernel.cc
浏览文件 @
41f11d29
// Copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 202
2
PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
...
...
@@ -12,19 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/phi/kernels/diag_embed_kernel.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/diag_embed_op.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/diag_embed_impl.h"
namespace
ops
=
paddle
::
operators
;
namespace
platform
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
diag_embed
,
ops
::
DiagEmbedKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
DiagEmbedKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
,
ops
::
DiagEmbedKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
DiagEmbedKernel
<
paddle
::
platform
::
CUDADeviceContext
,
platform
::
float16
>
,
ops
::
DiagEmbedKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
PD_REGISTER_KERNEL
(
diag_embed
,
CPU
,
ALL_LAYOUT
,
phi
::
DiagEmbedKernel
,
int
,
int64_t
,
float
,
double
)
{}
paddle/phi/kernels/diag_embed_kernel.h
0 → 100644
浏览文件 @
41f11d29
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
DiagEmbedKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int
offset
,
int
dim1
,
int
dim2
,
DenseTensor
*
out
);
}
// namespace phi
paddle/phi/kernels/gpu/diag_embed_kernel.cu
0 → 100644
浏览文件 @
41f11d29
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/diag_embed_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/diag_embed_impl.h"
PD_REGISTER_KERNEL
(
diag_embed
,
GPU
,
ALL_LAYOUT
,
phi
::
DiagEmbedKernel
,
int
,
int64_t
,
float
,
double
)
{}
paddle/
fluid/operators/diag_embed_op
.h
→
paddle/
phi/kernels/impl/diag_embed_impl
.h
浏览文件 @
41f11d29
// Copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 202
2
PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
...
...
@@ -14,15 +14,19 @@
#pragma once
#if defined(__NVCC__) || defined(__HIPCC__)
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#endif
#include "paddle/phi/kernels/diag_embed_kernel.h"
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
paddle
{
namespace
operators
{
namespace
phi
{
template
<
typename
T
>
struct
DiagEmbedFunctor
{
...
...
@@ -62,69 +66,64 @@ struct DiagEmbedFunctor {
const
int64_t
*
strides_
;
};
template
<
typename
DeviceContext
,
typename
T
>
class
DiagEmbedKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
input
=
context
.
Input
<
framework
::
Tensor
>
(
"Input"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
const
int64_t
offset
=
context
.
Attr
<
int
>
(
"offset"
);
const
int64_t
dim1
=
context
.
Attr
<
int
>
(
"dim1"
);
const
int64_t
dim2
=
context
.
Attr
<
int
>
(
"dim2"
);
auto
*
input_data
=
input
->
data
<
T
>
();
template
<
typename
T
,
typename
Context
>
void
DiagEmbedKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int
offset
,
int
dim1
,
int
dim2
,
DenseTensor
*
out
)
{
auto
*
input_data
=
x
.
data
<
T
>
();
T
*
out_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
phi
::
funcs
::
SetConstant
<
Context
,
T
>
set_zero
;
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
phi
::
funcs
::
SetConstant
<
DeviceContext
,
T
>
set_zero
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
set_zero
(
dev_ctx
,
out
,
static_cast
<
T
>
(
0.0
));
set_zero
(
dev_ctx
,
out
,
static_cast
<
T
>
(
0.0
));
auto
out_dims
=
out
->
dims
();
int
dim1_
=
dim1
<
0
?
out_dims
.
size
()
+
dim1
:
dim1
;
int
dim2_
=
dim2
<
0
?
out_dims
.
size
()
+
dim2
:
dim2
;
auto
stride
=
phi
::
stride
(
out_dims
);
int64_t
diag_size
;
int64_t
storage_offset
=
0
;
if
(
offset
>=
0
)
{
int64_t
dim
=
out_dims
[
dim2_
]
-
offset
;
diag_size
=
std
::
max
<
int64_t
>
(
std
::
min
(
out_dims
[
dim1_
],
dim
),
0
);
}
else
{
int64_t
dim
=
out_dims
[
dim1_
]
+
offset
;
diag_size
=
std
::
max
<
int64_t
>
(
std
::
min
(
dim
,
out_dims
[
dim2_
]),
0
);
}
if
(
diag_size
==
0
)
{
// skip
}
else
if
(
offset
>=
0
)
{
storage_offset
+=
offset
*
stride
[
dim2_
];
}
else
{
storage_offset
-=
offset
*
stride
[
dim1_
];
}
auto
strides
=
vectorize
(
stride
);
strides
.
erase
(
strides
.
begin
()
+
std
::
max
(
dim1_
,
dim2_
));
strides
.
erase
(
strides
.
begin
()
+
std
::
min
(
dim1_
,
dim2_
));
strides
.
push_back
(
stride
[
dim1_
]
+
stride
[
dim2_
]);
const
auto
dims
=
vectorize
(
input
->
dims
());
auto
out_dims
=
out
->
dims
();
int
dim1_
=
dim1
<
0
?
out_dims
.
size
()
+
dim1
:
dim1
;
int
dim2_
=
dim2
<
0
?
out_dims
.
size
()
+
dim2
:
dim2
;
auto
stride
=
phi
::
stride
(
out_dims
);
int64_t
diag_size
;
int64_t
storage_offset
=
0
;
if
(
offset
>=
0
)
{
int64_t
dim
=
out_dims
[
dim2_
]
-
offset
;
diag_size
=
std
::
max
<
int64_t
>
(
std
::
min
(
out_dims
[
dim1_
],
dim
),
0
);
}
else
{
int64_t
dim
=
out_dims
[
dim1_
]
+
offset
;
diag_size
=
std
::
max
<
int64_t
>
(
std
::
min
(
dim
,
out_dims
[
dim2_
]),
0
);
}
if
(
diag_size
==
0
)
{
// skip
}
else
if
(
offset
>=
0
)
{
storage_offset
+=
offset
*
stride
[
dim2_
];
}
else
{
storage_offset
-=
offset
*
stride
[
dim1_
];
}
auto
strides
=
vectorize
(
stride
);
strides
.
erase
(
strides
.
begin
()
+
std
::
max
(
dim1_
,
dim2_
));
strides
.
erase
(
strides
.
begin
()
+
std
::
min
(
dim1_
,
dim2_
));
strides
.
push_back
(
stride
[
dim1_
]
+
stride
[
dim2_
]);
const
auto
dims
=
vectorize
(
x
.
dims
());
#if defined(__NVCC__) || defined(__HIPCC__)
thrust
::
device_vector
<
int64_t
>
dims_vec
(
dims
);
const
int64_t
*
dims_arr
=
thrust
::
raw_pointer_cast
(
dims_vec
.
data
());
thrust
::
device_vector
<
int64_t
>
strides_vec
(
strides
);
const
int64_t
*
strides_arr
=
thrust
::
raw_pointer_cast
(
strides_vec
.
data
());
thrust
::
device_vector
<
int64_t
>
dims_vec
(
dims
);
const
int64_t
*
dims_arr
=
thrust
::
raw_pointer_cast
(
dims_vec
.
data
());
thrust
::
device_vector
<
int64_t
>
strides_vec
(
strides
);
const
int64_t
*
strides_arr
=
thrust
::
raw_pointer_cast
(
strides_vec
.
data
());
#else
const
int64_t
*
dims_arr
=
dims
.
data
();
const
int64_t
*
strides_arr
=
strides
.
data
();
const
int64_t
*
dims_arr
=
dims
.
data
();
const
int64_t
*
strides_arr
=
strides
.
data
();
#endif
platform
::
ForRange
<
DeviceContext
>
for_range
(
dev_ctx
,
input
->
numel
());
DiagEmbedFunctor
<
T
>
functor
(
input_data
,
input
->
numel
(),
dims_arr
,
storage_offset
,
dims
.
size
(),
out_data
,
strides_arr
);
for_range
(
functor
);
}
};
}
// namespace operators
}
// namespace paddle
phi
::
funcs
::
ForRange
<
Context
>
for_range
(
dev_ctx
,
x
.
numel
());
DiagEmbedFunctor
<
T
>
functor
(
input_data
,
x
.
numel
(),
dims_arr
,
storage_offset
,
dims
.
size
(),
out_data
,
strides_arr
);
for_range
(
functor
);
}
}
// namespace phi
python/paddle/fluid/tests/unittests/test_diag_embed.py
浏览文件 @
41f11d29
...
...
@@ -27,11 +27,12 @@ class TestDiagEmbedOp(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"diag_embed"
self
.
python_api
=
F
.
diag_embed
self
.
init_config
()
self
.
outputs
=
{
'Out'
:
self
.
target
}
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
check_eager
=
True
)
def
init_config
(
self
):
self
.
case
=
np
.
random
.
randn
(
2
,
3
).
astype
(
'float32'
)
...
...
python/paddle/nn/functional/extension.py
浏览文件 @
41f11d29
...
...
@@ -98,12 +98,18 @@ def diag_embed(input, offset=0, dim1=-2, dim2=-1):
# [[ 0. , 0. , 0. , 0. ],
# [ 0. , 0. , 0. , 0. ]]]
"""
inputs
=
{
'Input'
:
[
input
]}
attrs
=
{
'offset'
:
offset
,
'dim1'
:
dim1
,
'dim2'
:
dim2
}
if
not
isinstance
(
input
,
Variable
):
input
=
assign
(
input
)
if
in_dygraph_mode
():
return
_C_ops
.
final_state_diag_embed
(
input
,
offset
,
dim1
,
dim2
)
elif
in_dynamic_mode
():
return
_C_ops
.
diag_embed
(
input
,
"offset"
,
offset
,
"dim1"
,
dim1
,
"dim2"
,
dim2
)
inputs
=
{
'Input'
:
[
input
]}
attrs
=
{
'offset'
:
offset
,
'dim1'
:
dim1
,
'dim2'
:
dim2
}
def
__check_input
(
input
,
offset
,
dim1
,
dim2
):
check_dtype
(
input
.
dtype
,
'Input'
,
[
'int32'
,
'int64'
,
'float16'
,
'float32'
,
'float64'
],
...
...
@@ -129,8 +135,7 @@ def diag_embed(input, offset=0, dim1=-2, dim2=-1):
"dim1 and dim2 cannot be the same dimension."
\
"But received dim1 = %d, dim2 = %d
\n
"
%
(
dim1
,
dim2
)
if
not
in_dynamic_mode
():
__check_input
(
input
,
offset
,
dim1
,
dim2
)
__check_input
(
input
,
offset
,
dim1
,
dim2
)
helper
=
LayerHelper
(
"diag_embed"
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
input
.
dtype
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录