Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f58fe6d3
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f58fe6d3
编写于
1月 02, 2018
作者:
C
chengduo
提交者:
GitHub
1月 02, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #6601 from chengduoZH/profiling/cosine_op
Refine cos-sim-op
上级
0bd7f97b
812c5f60
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
332 addition
and
66 deletion
+332
-66
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+1
-0
paddle/operators/cos_sim_op.h
paddle/operators/cos_sim_op.h
+51
-66
paddle/operators/math/CMakeLists.txt
paddle/operators/math/CMakeLists.txt
+2
-0
paddle/operators/math/cos_sim_functor.cc
paddle/operators/math/cos_sim_functor.cc
+48
-0
paddle/operators/math/cos_sim_functor.cu
paddle/operators/math/cos_sim_functor.cu
+64
-0
paddle/operators/math/cos_sim_functor.h
paddle/operators/math/cos_sim_functor.h
+166
-0
未找到文件。
paddle/operators/CMakeLists.txt
浏览文件 @
f58fe6d3
...
...
@@ -229,6 +229,7 @@ op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library
(
conv_transpose_op DEPS vol2col
)
op_library
(
gru_op DEPS sequence2batch gru_compute
)
op_library
(
recurrent_op DEPS executor
)
op_library
(
cos_sim_op DEPS cos_sim_functor
)
# FIXME(typhoonzero): save/load depends lodtensor serialization functions
op_library
(
save_op DEPS lod_tensor
)
op_library
(
load_op DEPS lod_tensor
)
...
...
paddle/operators/cos_sim_op.h
浏览文件 @
f58fe6d3
...
...
@@ -13,19 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/cos_sim_functor.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenVector
=
framework
::
EigenVector
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
DeviceContext
,
typename
T
>
class
CosSimKernel
:
public
framework
::
OpKernel
<
T
>
{
...
...
@@ -41,28 +37,25 @@ class CosSimKernel : public framework::OpKernel<T> {
out_x_norm
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out_y_norm
->
mutable_data
<
T
>
(
context
.
GetPlace
());
// convert Tensor to Eigen Tensor
int
rows_x
=
in_x
->
dims
()[
0
];
int
rows_y
=
in_y
->
dims
()[
0
];
auto
x
=
EigenMatrix
<
T
>::
Reshape
(
*
in_x
,
1
);
auto
y
=
EigenMatrix
<
T
>::
Reshape
(
*
in_y
,
1
);
auto
z
=
EigenVector
<
T
>::
Flatten
(
*
out_z
);
auto
x_norm
=
EigenVector
<
T
>::
Flatten
(
*
out_x_norm
);
auto
y_norm
=
EigenVector
<
T
>::
Flatten
(
*
out_y_norm
);
// compute
auto
&
place
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
row_along
=
Eigen
::
array
<
int
,
1
>
({{
1
}});
x_norm
.
device
(
place
)
=
x
.
square
().
sum
(
row_along
).
sqrt
();
y_norm
.
device
(
place
)
=
y
.
square
().
sum
(
row_along
).
sqrt
();
int
cols
=
framework
::
product
(
in_x
->
dims
())
/
rows_x
;
if
(
rows_x
==
rows_y
)
{
auto
xy
=
(
x
*
y
).
sum
(
Eigen
::
array
<
int
,
1
>
({{
1
}}));
z
.
device
(
place
)
=
xy
/
x_norm
/
y_norm
;
math
::
CosSimFunctor
<
T
,
true
>
functor
(
in_x
->
data
<
T
>
(),
in_y
->
data
<
T
>
(),
out_x_norm
->
data
<
T
>
(),
out_y_norm
->
data
<
T
>
(),
out_z
->
data
<
T
>
(),
cols
);
platform
::
ForRange
<
DeviceContext
>
for_range
(
static_cast
<
const
DeviceContext
&>
(
context
.
device_context
()),
rows_x
);
for_range
(
functor
);
}
else
{
Eigen
::
DSizes
<
int
,
2
>
bcast
(
rows_x
,
1
);
auto
xy
=
(
x
*
y
.
broadcast
(
bcast
)).
sum
(
row_along
);
z
.
device
(
place
)
=
xy
/
x_norm
/
y_norm
.
broadcast
(
bcast
);
math
::
CosSimFunctor
<
T
,
false
>
functor
(
in_x
->
data
<
T
>
(),
in_y
->
data
<
T
>
(),
out_x_norm
->
data
<
T
>
(),
out_y_norm
->
data
<
T
>
(),
out_z
->
data
<
T
>
(),
cols
);
platform
::
ForRange
<
DeviceContext
>
for_range
(
static_cast
<
const
DeviceContext
&>
(
context
.
device_context
()),
rows_x
);
for_range
(
functor
);
}
}
};
...
...
@@ -81,62 +74,54 @@ class CosSimGradKernel : public framework::OpKernel<T> {
auto
*
out_grad_y
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
in_grad_z
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
// convert Tensor to Eigen Tensor
auto
x
=
EigenMatrix
<
T
>::
Reshape
(
*
in_x
,
1
);
auto
y
=
EigenMatrix
<
T
>::
Reshape
(
*
in_y
,
1
);
auto
z
=
EigenMatrix
<
T
>::
Reshape
(
*
in_z
,
1
);
auto
x_norm
=
EigenMatrix
<
T
>::
Reshape
(
*
in_x_norm
,
1
);
auto
y_norm
=
EigenMatrix
<
T
>::
Reshape
(
*
in_y_norm
,
1
);
auto
dz
=
EigenMatrix
<
T
>::
Reshape
(
*
in_grad_z
,
1
);
// compute gradident
int
rows_x
=
in_x
->
dims
()[
0
];
int
rows_y
=
in_y
->
dims
()[
0
];
int
cols
=
framework
::
product
(
in_x
->
dims
())
/
rows_x
;
Eigen
::
DSizes
<
int
,
2
>
bcast_cols
(
1
,
cols
);
auto
z_bcast
=
z
.
broadcast
(
bcast_cols
);
auto
dz_bcast
=
dz
.
broadcast
(
bcast_cols
);
auto
x_snorm_bcast
=
x_norm
.
square
().
eval
().
broadcast
(
bcast_cols
);
auto
&
place
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
if
(
rows_x
==
rows_y
)
{
auto
y_snorm_bcast
=
y_norm
.
square
().
eval
().
broadcast
(
bcast_cols
);
auto
norm_prod_bcast
=
(
x_norm
*
y_norm
).
eval
().
broadcast
(
bcast_cols
);
// compute dx
if
(
out_grad_x
)
{
out_grad_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
dx
=
EigenMatrix
<
T
>::
Reshape
(
*
out_grad_x
,
1
);
auto
grad
=
y
/
norm_prod_bcast
-
z_bcast
*
x
/
x_snorm_bcast
;
dx
.
device
(
place
)
=
dz_bcast
*
grad
;
math
::
CosSimGradFunctor
<
T
>
functor
(
in_x_norm
->
data
<
T
>
(),
in_y_norm
->
data
<
T
>
(),
in_x
->
data
<
T
>
(),
in_y
->
data
<
T
>
(),
in_z
->
data
<
T
>
(),
in_grad_z
->
data
<
T
>
(),
out_grad_x
->
mutable_data
<
T
>
(
context
.
GetPlace
()),
cols
);
platform
::
ForRange
<
DeviceContext
>
for_range
(
static_cast
<
const
DeviceContext
&>
(
context
.
device_context
()),
rows_x
);
for_range
(
functor
);
}
// compute dy
if
(
out_grad_y
)
{
out_grad_y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
dy
=
EigenMatrix
<
T
>::
Reshape
(
*
out_grad_y
,
1
);
auto
grad
=
x
/
norm_prod_bcast
-
z_bcast
*
y
/
y_snorm_bcast
;
dy
.
device
(
place
)
=
dz_bcast
*
grad
;
math
::
CosSimGradFunctor
<
T
>
functor
(
in_y_norm
->
data
<
T
>
(),
in_x_norm
->
data
<
T
>
(),
in_y
->
data
<
T
>
(),
in_x
->
data
<
T
>
(),
in_z
->
data
<
T
>
(),
in_grad_z
->
data
<
T
>
(),
out_grad_y
->
mutable_data
<
T
>
(
context
.
GetPlace
()),
cols
);
platform
::
ForRange
<
DeviceContext
>
for_range
(
static_cast
<
const
DeviceContext
&>
(
context
.
device_context
()),
rows_x
);
for_range
(
functor
);
}
}
else
{
Eigen
::
DSizes
<
int
,
2
>
bcast_rows
(
rows_x
,
1
);
Eigen
::
DSizes
<
int
,
2
>
bcast_rows_cols
(
rows_x
,
cols
);
auto
y_bcast
=
y
.
broadcast
(
bcast_rows
);
auto
y_snorm_bcast
=
y_norm
.
square
().
eval
().
broadcast
(
bcast_rows_cols
);
auto
norm_prod_bcast
=
(
x_norm
*
y_norm
.
eval
().
broadcast
(
bcast_rows
))
.
eval
()
.
broadcast
(
bcast_cols
);
// compute dx
if
(
out_grad_x
)
{
out_grad_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
dx
=
EigenMatrix
<
T
>::
Reshape
(
*
out_grad_x
,
1
);
auto
grad
=
y_bcast
/
norm_prod_bcast
-
z_bcast
*
x
/
x_snorm_bcast
;
dx
.
device
(
place
)
=
dz_bcast
*
grad
;
math
::
CosSimDxFunctor
<
T
>
functor
(
in_x_norm
->
data
<
T
>
(),
in_y_norm
->
data
<
T
>
(),
in_x
->
data
<
T
>
(),
in_y
->
data
<
T
>
(),
in_z
->
data
<
T
>
(),
in_grad_z
->
data
<
T
>
(),
out_grad_x
->
mutable_data
<
T
>
(
context
.
GetPlace
()),
cols
);
platform
::
ForRange
<
DeviceContext
>
for_range
(
static_cast
<
const
DeviceContext
&>
(
context
.
device_context
()),
rows_x
);
for_range
(
functor
);
}
// compute dy
if
(
out_grad_y
)
{
out_grad_y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
dy
=
EigenVector
<
T
>::
Flatten
(
*
out_grad_y
);
auto
grad
=
x
/
norm_prod_bcast
-
z_bcast
*
y_bcast
/
y_snorm_bcast
;
dy
.
device
(
place
)
=
(
dz_bcast
*
grad
).
sum
(
Eigen
::
array
<
int
,
1
>
({{
0
}}));
math
::
SetConstant
<
DeviceContext
,
T
>
set_zero
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
set_zero
(
dev_ctx
,
out_grad_y
,
static_cast
<
T
>
(
0
));
math
::
CosSimDyFunctor
<
DeviceContext
,
T
>
functor
;
functor
(
dev_ctx
,
in_x_norm
->
data
<
T
>
(),
in_y_norm
->
data
<
T
>
(),
in_x
->
data
<
T
>
(),
in_y
->
data
<
T
>
(),
in_z
->
data
<
T
>
(),
in_grad_z
->
data
<
T
>
(),
static_cast
<
size_t
>
(
rows_x
),
static_cast
<
size_t
>
(
cols
),
out_grad_y
->
data
<
T
>
());
}
}
}
...
...
paddle/operators/math/CMakeLists.txt
浏览文件 @
f58fe6d3
...
...
@@ -16,6 +16,7 @@ if(WITH_GPU)
nv_library
(
maxouting SRCS maxouting.cc maxouting.cu DEPS device_context
)
nv_library
(
unpooling SRCS unpooling.cc unpooling.cu DEPS device_context
)
nv_library
(
gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function
)
nv_library
(
cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context
)
else
()
cc_library
(
math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto
)
cc_library
(
selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function
)
...
...
@@ -30,6 +31,7 @@ else()
cc_library
(
maxouting SRCS maxouting.cc DEPS device_context
)
cc_library
(
unpooling SRCS unpooling.cc DEPS device_context
)
cc_library
(
gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function
)
cc_library
(
cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context
)
endif
()
cc_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
...
...
paddle/operators/math/cos_sim_functor.cc
0 → 100644
浏览文件 @
f58fe6d3
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/cos_sim_functor.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
struct
CosSimDyFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
T
*
x_norm
,
const
T
*
y_norm
,
const
T
*
x
,
const
T
*
y
,
const
T
*
z
,
const
T
*
dz
,
const
size_t
rows
,
const
size_t
cols
,
T
*
dy
)
const
{
for
(
size_t
row_id
=
0
;
row_id
<
rows
;
++
row_id
)
{
auto
xy_norm_prod
=
x_norm
[
row_id
]
*
y_norm
[
0
];
auto
dz_data
=
dz
[
row_id
];
auto
z_data
=
z
[
row_id
];
auto
*
x_data
=
x
+
cols
*
row_id
;
auto
reciprocal_xy_norm_prod
=
1
/
xy_norm_prod
;
auto
y_norm_square
=
y_norm
[
0
]
*
y_norm
[
0
];
auto
reciprocal_y_norm_square
=
1
/
y_norm_square
;
for
(
size_t
i
=
0
;
i
<
cols
;
++
i
)
{
dy
[
i
]
+=
dz_data
*
(
x_data
[
i
]
*
reciprocal_xy_norm_prod
-
z_data
*
y
[
i
]
*
reciprocal_y_norm_square
);
}
}
}
};
template
class
CosSimDyFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
class
CosSimDyFunctor
<
platform
::
CPUDeviceContext
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/cos_sim_functor.cu
0 → 100644
浏览文件 @
f58fe6d3
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/cos_sim_functor.h"
#include "paddle/platform/cuda_helper.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
__global__
void
CosSimDyKernel
(
const
T
*
x_norm
,
const
T
*
y_norm
,
const
T
*
x
,
const
T
*
y
,
const
T
*
z
,
const
T
*
dz
,
const
size_t
rows
,
const
size_t
cols
,
T
*
dy
)
{
int
grid_size
=
blockDim
.
x
*
gridDim
.
x
;
T
y_norm_data
=
y_norm
[
0
];
for
(
int
row_id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
row_id
<
rows
;
row_id
+=
grid_size
)
{
T
xy_norm_prod
=
x_norm
[
row_id
]
*
y_norm_data
;
T
dz_data
=
dz
[
row_id
];
T
z_data
=
z
[
row_id
];
const
T
*
x_data
=
x
+
cols
*
row_id
;
T
reciprocal_xy_norm_prod
=
1
/
xy_norm_prod
;
T
y_norm_square
=
y_norm_data
*
y_norm_data
;
T
reciprocal_y_norm_square
=
1
/
y_norm_square
;
for
(
size_t
i
=
0
;
i
<
cols
;
++
i
)
{
T
dy_data
=
dz_data
*
(
x_data
[
i
]
*
reciprocal_xy_norm_prod
-
z_data
*
y
[
i
]
*
reciprocal_y_norm_square
);
platform
::
CudaAtomicAdd
(
dy
+
i
,
dy_data
);
}
}
}
template
<
typename
T
>
struct
CosSimDyFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
T
*
x_norm
,
const
T
*
y_norm
,
const
T
*
x
,
const
T
*
y
,
const
T
*
z
,
const
T
*
dz
,
const
size_t
rows
,
const
size_t
cols
,
T
*
dy
)
const
{
const
int
block_size
=
512
;
dim3
threads
(
block_size
,
1
);
dim3
grid
(
1
,
(
rows
+
block_size
-
1
)
/
block_size
);
CosSimDyKernel
<
T
><<<
grid
,
threads
,
0
,
ctx
.
stream
()
>>>
(
x_norm
,
y_norm
,
x
,
y
,
z
,
dz
,
rows
,
cols
,
dy
);
}
};
template
class
CosSimDyFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
CosSimDyFunctor
<
platform
::
CUDADeviceContext
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/cos_sim_functor.h
0 → 100644
浏览文件 @
f58fe6d3
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include <stdlib.h>
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
,
bool
same_row
>
struct
CosSimFunctor
{
CosSimFunctor
(
const
T
*
x
,
const
T
*
y
,
T
*
x_norm
,
T
*
y_norm
,
T
*
z
,
int
cols
)
:
x_norm_
(
x_norm
),
y_norm_
(
y_norm
),
x_
(
x
),
y_
(
y
),
z_
(
z
),
cols_
(
static_cast
<
size_t
>
(
cols
))
{}
inline
HOSTDEVICE
void
operator
()(
size_t
row_id
)
const
{
auto
*
x
=
x_
+
cols_
*
row_id
;
T
xx
=
0
,
xy
=
0
,
yy
=
0
;
if
(
same_row
)
{
auto
*
y
=
y_
+
cols_
*
row_id
;
T
tep_x
,
tep_y
;
for
(
size_t
i
=
0
;
i
<
cols_
;
++
i
)
{
tep_x
=
x
[
i
];
tep_y
=
y
[
i
];
xx
+=
tep_x
*
tep_x
;
yy
+=
tep_y
*
tep_y
;
xy
+=
tep_x
*
tep_y
;
}
xx
=
sqrt
(
xx
);
yy
=
sqrt
(
yy
);
y_norm_
[
row_id
]
=
yy
;
x_norm_
[
row_id
]
=
xx
;
z_
[
row_id
]
=
xy
/
(
xx
*
yy
);
}
else
{
// This can be wrote in a better way.
T
tep_x
,
tep_y
;
for
(
size_t
i
=
0
;
i
<
cols_
;
++
i
)
{
tep_x
=
x
[
i
];
tep_y
=
y_
[
i
];
xx
+=
tep_x
*
tep_x
;
yy
+=
tep_y
*
tep_y
;
xy
+=
tep_x
*
tep_y
;
}
xx
=
sqrt
(
xx
);
yy
=
sqrt
(
yy
);
if
(
row_id
==
0
)
y_norm_
[
0
]
=
yy
;
x_norm_
[
row_id
]
=
xx
;
z_
[
row_id
]
=
xy
/
(
xx
*
yy
);
}
}
T
*
x_norm_
;
T
*
y_norm_
;
const
T
*
x_
;
const
T
*
y_
;
T
*
z_
;
const
size_t
cols_
;
};
template
<
typename
T
>
struct
CosSimGradFunctor
{
CosSimGradFunctor
(
const
T
*
x_norm
,
const
T
*
y_norm
,
const
T
*
x
,
const
T
*
y
,
const
T
*
z
,
const
T
*
dz
,
T
*
dx
,
int
cols
)
:
x_norm_
(
x_norm
),
y_norm_
(
y_norm
),
x_
(
x
),
y_
(
y
),
z_
(
z
),
dz_
(
dz
),
dx_
(
dx
),
cols_
(
static_cast
<
size_t
>
(
cols
))
{}
inline
HOSTDEVICE
void
operator
()(
size_t
row_id
)
const
{
auto
x_norm_square
=
x_norm_
[
row_id
]
*
x_norm_
[
row_id
];
auto
xy_norm_prod
=
x_norm_
[
row_id
]
*
y_norm_
[
row_id
];
auto
dz
=
dz_
[
row_id
];
auto
z
=
z_
[
row_id
];
auto
*
dx
=
dx_
+
cols_
*
row_id
;
auto
*
x
=
x_
+
cols_
*
row_id
;
auto
*
y
=
y_
+
cols_
*
row_id
;
auto
reciprocal_xy_norm_prod
=
1
/
xy_norm_prod
;
auto
reciprocal_x_norm_square
=
1
/
x_norm_square
;
for
(
size_t
i
=
0
;
i
<
cols_
;
++
i
)
{
dx
[
i
]
=
dz
*
(
y
[
i
]
*
reciprocal_xy_norm_prod
-
z
*
x
[
i
]
*
reciprocal_x_norm_square
);
}
}
const
T
*
x_norm_
;
const
T
*
y_norm_
;
const
T
*
x_
;
const
T
*
y_
;
const
T
*
z_
;
const
T
*
dz_
;
T
*
dx_
;
const
size_t
cols_
;
};
template
<
typename
T
>
struct
CosSimDxFunctor
{
CosSimDxFunctor
(
const
T
*
x_norm
,
const
T
*
y_norm
,
const
T
*
x
,
const
T
*
y
,
const
T
*
z
,
const
T
*
dz
,
T
*
dx
,
int
cols
)
:
x_norm_
(
x_norm
),
y_norm_
(
y_norm
),
x_
(
x
),
y_
(
y
),
z_
(
z
),
dz_
(
dz
),
dx_
(
dx
),
cols_
(
static_cast
<
size_t
>
(
cols
))
{}
inline
HOSTDEVICE
void
operator
()(
size_t
row_id
)
const
{
auto
xy_norm_prod
=
x_norm_
[
row_id
]
*
y_norm_
[
0
];
auto
dz
=
dz_
[
row_id
];
auto
z
=
z_
[
row_id
];
auto
*
x
=
x_
+
cols_
*
row_id
;
auto
reciprocal_xy_norm_prod
=
1
/
xy_norm_prod
;
auto
x_norm_square
=
x_norm_
[
row_id
]
*
x_norm_
[
row_id
];
auto
*
dx
=
dx_
+
cols_
*
row_id
;
auto
reciprocal_x_norm_square
=
1
/
x_norm_square
;
for
(
size_t
i
=
0
;
i
<
cols_
;
++
i
)
{
dx
[
i
]
=
dz
*
(
y_
[
i
]
*
reciprocal_xy_norm_prod
-
z
*
x
[
i
]
*
reciprocal_x_norm_square
);
}
}
const
T
*
x_norm_
;
const
T
*
y_norm_
;
const
T
*
x_
;
const
T
*
y_
;
const
T
*
z_
;
const
T
*
dz_
;
T
*
dx_
;
const
size_t
cols_
;
};
template
<
typename
DeviceContext
,
typename
T
>
struct
CosSimDyFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
T
*
x_norm
,
const
T
*
y_norm
,
const
T
*
x
,
const
T
*
y
,
const
T
*
z
,
const
T
*
dz
,
const
size_t
rows
,
const
size_t
cols
,
T
*
dy
)
const
;
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录