Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2553af4f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2553af4f
编写于
2月 25, 2022
作者:
F
furnace
提交者:
GitHub
2月 25, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Phi] mv kernel (#39861)
[Phi] mv kernel
上级
22f84122
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
346 addition
and
207 deletion
+346
-207
paddle/fluid/operators/mv_op.cc
paddle/fluid/operators/mv_op.cc
+7
-8
paddle/fluid/operators/mv_op.cu
paddle/fluid/operators/mv_op.cu
+0
-94
paddle/fluid/operators/mv_op.h
paddle/fluid/operators/mv_op.h
+0
-105
paddle/phi/kernels/cpu/mv_grad_kernel.cc
paddle/phi/kernels/cpu/mv_grad_kernel.cc
+72
-0
paddle/phi/kernels/cpu/mv_kernel.cc
paddle/phi/kernels/cpu/mv_kernel.cc
+22
-0
paddle/phi/kernels/gpu/mv_grad_kernel.cu
paddle/phi/kernels/gpu/mv_grad_kernel.cu
+83
-0
paddle/phi/kernels/gpu/mv_kernel.cu
paddle/phi/kernels/gpu/mv_kernel.cu
+22
-0
paddle/phi/kernels/impl/mv_kernel_impl.h
paddle/phi/kernels/impl/mv_kernel_impl.h
+45
-0
paddle/phi/kernels/mv_grad_kernel.h
paddle/phi/kernels/mv_grad_kernel.h
+35
-0
paddle/phi/kernels/mv_kernel.h
paddle/phi/kernels/mv_kernel.h
+27
-0
paddle/phi/ops/compat/mv_sig.cc
paddle/phi/ops/compat/mv_sig.cc
+33
-0
未找到文件。
paddle/fluid/operators/mv_op.cc
浏览文件 @
2553af4f
...
@@ -12,7 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,7 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/mv_op.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -116,10 +122,3 @@ REGISTER_OPERATOR(mv, ops::MVOp, ops::MVOpMaker,
...
@@ -116,10 +122,3 @@ REGISTER_OPERATOR(mv, ops::MVOp, ops::MVOpMaker,
ops
::
MVOpGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
MVOpGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
MVOpGradMaker
<
paddle
::
imperative
::
OpBase
>
);
ops
::
MVOpGradMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
mv_grad
,
ops
::
MVOpGrad
);
REGISTER_OPERATOR
(
mv_grad
,
ops
::
MVOpGrad
);
REGISTER_OP_CPU_KERNEL
(
mv
,
ops
::
MVKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
MVKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
mv_grad
,
ops
::
MVGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
MVGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/mv_op.cu
已删除
100644 → 0
浏览文件 @
22f84122
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mv_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
__global__
void
MVGradDxCUDAKernel
(
const
int
m
,
const
int
n
,
const
T
*
dout
,
const
T
*
vec
,
T
*
dx
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(;
idx
<
m
*
n
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
i
=
idx
/
n
;
int
j
=
idx
%
n
;
dx
[
idx
]
=
dout
[
i
]
*
vec
[
j
];
}
}
// Using dimensional constraints on matrix multiplication, it is
// straight-forward to check the following table for when X and Y
// are both matrices.
//
// dX = | dOut Vec^T
// dVec = | X^T dOut
template
<
typename
T
>
class
MVGradKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
vec
=
context
.
Input
<
framework
::
Tensor
>
(
"Vec"
);
auto
*
dout
=
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dx
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dvec
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Vec"
));
auto
dim_x
=
x
->
dims
();
int
m
=
dim_x
[
0
];
int
n
=
dim_x
[
1
];
// get data ptr
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
vec_data
=
vec
->
data
<
T
>
();
const
T
*
dout_data
=
dout
->
data
<
T
>
();
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
blas
=
phi
::
funcs
::
GetBlas
<
platform
::
CUDADeviceContext
,
T
>
(
dev_ctx
);
auto
stream
=
context
.
cuda_device_context
().
stream
();
auto
config
=
GetGpuLaunchConfig1D
(
dev_ctx
,
m
*
n
);
if
(
dx
)
{
T
*
dx_data
=
dx
->
mutable_data
<
T
>
(
context
.
GetPlace
());
MVGradDxCUDAKernel
<
T
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
stream
>>>
(
m
,
n
,
dout_data
,
vec_data
,
dx_data
);
}
if
(
dvec
)
{
T
*
dvec_data
=
dvec
->
mutable_data
<
T
>
(
context
.
GetPlace
());
blas
.
GEMV
(
true
,
dim_x
[
0
],
dim_x
[
1
],
static_cast
<
T
>
(
1
),
x_data
,
dout_data
,
static_cast
<
T
>
(
0
),
dvec_data
);
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
mv
,
ops
::
MVKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
MVKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
mv_grad
,
ops
::
MVGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
MVGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/mv_op.h
已删除
100644 → 0
浏览文件 @
22f84122
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
DeviceContext
,
typename
T
>
class
MVKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
vec
=
context
.
Input
<
framework
::
Tensor
>
(
"Vec"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
dim_x
=
x
->
dims
();
// get data ptr
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
vec_data
=
vec
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
blas
=
phi
::
funcs
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
blas
.
GEMV
(
false
,
dim_x
[
0
],
dim_x
[
1
],
static_cast
<
T
>
(
1
),
x_data
,
vec_data
,
static_cast
<
T
>
(
0
),
out_data
);
}
};
// Using dimensional constraints on matrix multiplication, it is
// straight-forward to check the following table for when X and Y
// are both matrices.
//
// dX = | dOut vec^T
// dVec = | X^T dOut
template
<
typename
DeviceContext
,
typename
T
>
class
MVGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
vec
=
context
.
Input
<
framework
::
Tensor
>
(
"Vec"
);
auto
*
dout
=
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dx
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dvec
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Vec"
));
auto
dim_x
=
x
->
dims
();
int
m
=
dim_x
[
0
];
int
n
=
dim_x
[
1
];
// get data ptr
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
vec_data
=
vec
->
data
<
T
>
();
const
T
*
dout_data
=
dout
->
data
<
T
>
();
if
(
dx
)
{
T
*
dx_data
=
dx
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
m
;
++
i
)
{
for
(
int
j
=
0
;
j
<
n
;
++
j
)
{
dx_data
[
i
*
n
+
j
]
=
dout_data
[
i
]
*
vec_data
[
j
];
}
}
}
if
(
dvec
)
{
T
*
dvec_data
=
dvec
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
blas
=
phi
::
funcs
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
blas
.
GEMV
(
true
,
dim_x
[
0
],
dim_x
[
1
],
static_cast
<
T
>
(
1
),
x_data
,
dout_data
,
static_cast
<
T
>
(
0
),
dvec_data
);
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/phi/kernels/cpu/mv_grad_kernel.cc
0 → 100644
浏览文件 @
2553af4f
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/mv_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
MvGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
vec
,
const
DenseTensor
&
out_grad
,
DenseTensor
*
x_grad
,
DenseTensor
*
vec_grad
)
{
auto
dout
=
out_grad
;
auto
dx
=
x_grad
;
auto
dvec
=
vec_grad
;
auto
dim_x
=
x
.
dims
();
int
m
=
dim_x
[
0
];
int
n
=
dim_x
[
1
];
// get data ptr
const
T
*
x_data
=
x
.
data
<
T
>
();
const
T
*
vec_data
=
vec
.
data
<
T
>
();
const
T
*
dout_data
=
dout
.
data
<
T
>
();
if
(
dx
)
{
T
*
dx_data
=
dev_ctx
.
template
Alloc
<
T
>(
dx
);
for
(
int
i
=
0
;
i
<
m
;
++
i
)
{
for
(
int
j
=
0
;
j
<
n
;
++
j
)
{
dx_data
[
i
*
n
+
j
]
=
dout_data
[
i
]
*
vec_data
[
j
];
}
}
}
if
(
dvec
)
{
T
*
dvec_data
=
dev_ctx
.
template
Alloc
<
T
>(
dvec
);
auto
blas
=
phi
::
funcs
::
GetBlas
<
Context
,
T
>
(
dev_ctx
);
blas
.
GEMV
(
true
,
dim_x
[
0
],
dim_x
[
1
],
static_cast
<
T
>
(
1
),
x_data
,
dout_data
,
static_cast
<
T
>
(
0
),
dvec_data
);
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
mv_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
MvGradKernel
,
float
,
double
)
{
}
paddle/phi/kernels/cpu/mv_kernel.cc
0 → 100644
浏览文件 @
2553af4f
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/mv_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/mv_kernel_impl.h"
PD_REGISTER_KERNEL
(
mv
,
CPU
,
ALL_LAYOUT
,
phi
::
MvKernel
,
float
,
double
)
{}
paddle/phi/kernels/gpu/mv_grad_kernel.cu
0 → 100644
浏览文件 @
2553af4f
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/mv_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace
phi
{
template
<
typename
T
>
__global__
void
MVGradDxCUDAKernel
(
const
int
m
,
const
int
n
,
const
T
*
dout
,
const
T
*
vec
,
T
*
dx
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(;
idx
<
m
*
n
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
i
=
idx
/
n
;
int
j
=
idx
%
n
;
dx
[
idx
]
=
dout
[
i
]
*
vec
[
j
];
}
}
template
<
typename
T
,
typename
Context
>
void
MvGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
vec
,
const
DenseTensor
&
out_grad
,
DenseTensor
*
x_grad
,
DenseTensor
*
vec_grad
)
{
auto
dout
=
out_grad
;
auto
dx
=
x_grad
;
auto
dvec
=
vec_grad
;
auto
dim_x
=
x
.
dims
();
int
m
=
dim_x
[
0
];
int
n
=
dim_x
[
1
];
// get data ptr
const
T
*
x_data
=
x
.
data
<
T
>
();
const
T
*
vec_data
=
vec
.
data
<
T
>
();
const
T
*
dout_data
=
dout
.
data
<
T
>
();
auto
blas
=
phi
::
funcs
::
GetBlas
<
Context
,
T
>
(
dev_ctx
);
auto
stream
=
dev_ctx
.
stream
();
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
m
*
n
);
if
(
dx
)
{
T
*
dx_data
=
dev_ctx
.
template
Alloc
<
T
>(
dx
);
MVGradDxCUDAKernel
<
T
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
stream
>>>
(
m
,
n
,
dout_data
,
vec_data
,
dx_data
);
}
if
(
dvec
)
{
T
*
dvec_data
=
dev_ctx
.
template
Alloc
<
T
>(
dvec
);
blas
.
GEMV
(
true
,
dim_x
[
0
],
dim_x
[
1
],
static_cast
<
T
>
(
1
),
x_data
,
dout_data
,
static_cast
<
T
>
(
0
),
dvec_data
);
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
mv_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
MvGradKernel
,
float
,
double
)
{
}
paddle/phi/kernels/gpu/mv_kernel.cu
0 → 100644
浏览文件 @
2553af4f
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/mv_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/mv_kernel_impl.h"
PD_REGISTER_KERNEL
(
mv
,
GPU
,
ALL_LAYOUT
,
phi
::
MvKernel
,
float
,
double
)
{}
paddle/phi/kernels/impl/mv_kernel_impl.h
0 → 100644
浏览文件 @
2553af4f
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
MvKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
vec
,
DenseTensor
*
out
)
{
auto
dim_x
=
x
.
dims
();
// get data ptr
const
T
*
x_data
=
x
.
data
<
T
>
();
const
T
*
vec_data
=
vec
.
data
<
T
>
();
T
*
out_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
auto
blas
=
phi
::
funcs
::
GetBlas
<
Context
,
T
>
(
dev_ctx
);
blas
.
GEMV
(
false
,
dim_x
[
0
],
dim_x
[
1
],
static_cast
<
T
>
(
1
),
x_data
,
vec_data
,
static_cast
<
T
>
(
0
),
out_data
);
}
}
// namespace phi
paddle/phi/kernels/mv_grad_kernel.h
0 → 100644
浏览文件 @
2553af4f
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
// Using dimensional constraints on matrix multiplication, it is
// straight-forward to check the following table for when X and Y
// are both matrices.
//
// dX = | dOut vec^T
// dVec = | X^T dOut
template
<
typename
T
,
typename
Context
>
void
MvGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
vec
,
const
DenseTensor
&
out_grad
,
DenseTensor
*
x_grad
,
DenseTensor
*
vec_grad
);
}
// namespace phi
paddle/phi/kernels/mv_kernel.h
0 → 100644
浏览文件 @
2553af4f
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
MvKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
vec
,
DenseTensor
*
out
);
}
// namepsace phi
paddle/phi/ops/compat/mv_sig.cc
0 → 100644
浏览文件 @
2553af4f
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
MvOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"mv"
,
{
"X"
,
"Vec"
},
{},
{
"Out"
});
}
KernelSignature
MvGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"mv_grad"
,
{
"X"
,
"Vec"
,
GradVarName
(
"Out"
)},
{},
{
GradVarName
(
"X"
),
GradVarName
(
"Vec"
)});
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
mv
,
phi
::
MvOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
mv_grad
,
phi
::
MvGradOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录