Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
95089204
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
95089204
编写于
5月 26, 2020
作者:
S
ShenLiang
提交者:
GitHub
5月 26, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix conflict, test=develop (#24238)
上级
c3c61d34
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
790 addition
and
168 deletion
+790
-168
paddle/fluid/framework/fleet/box_wrapper.cc
paddle/fluid/framework/fleet/box_wrapper.cc
+90
-128
paddle/fluid/framework/fleet/box_wrapper.cu
paddle/fluid/framework/fleet/box_wrapper.cu
+107
-22
paddle/fluid/framework/fleet/box_wrapper.h
paddle/fluid/framework/fleet/box_wrapper.h
+52
-9
paddle/fluid/framework/fleet/box_wrapper_impl.h
paddle/fluid/framework/fleet/box_wrapper_impl.h
+163
-0
paddle/fluid/operators/pull_box_extended_sparse_op.cc
paddle/fluid/operators/pull_box_extended_sparse_op.cc
+157
-0
paddle/fluid/operators/pull_box_extended_sparse_op.cu
paddle/fluid/operators/pull_box_extended_sparse_op.cu
+46
-0
paddle/fluid/operators/pull_box_extended_sparse_op.h
paddle/fluid/operators/pull_box_extended_sparse_op.h
+119
-0
paddle/fluid/operators/pull_box_sparse_op.h
paddle/fluid/operators/pull_box_sparse_op.h
+2
-2
paddle/fluid/pybind/box_helper_py.cc
paddle/fluid/pybind/box_helper_py.cc
+2
-2
python/paddle/fluid/contrib/layers/nn.py
python/paddle/fluid/contrib/layers/nn.py
+49
-1
python/paddle/fluid/tests/unittests/test_paddlebox_datafeed.py
...n/paddle/fluid/tests/unittests/test_paddlebox_datafeed.py
+3
-4
未找到文件。
paddle/fluid/framework/fleet/box_wrapper.cc
浏览文件 @
95089204
...
...
@@ -28,6 +28,8 @@ std::shared_ptr<BoxWrapper> BoxWrapper::s_instance_ = nullptr;
cudaStream_t
BoxWrapper
::
stream_list_
[
8
];
std
::
shared_ptr
<
boxps
::
BoxPSBase
>
BoxWrapper
::
boxps_ptr_
=
nullptr
;
AfsManager
*
BoxWrapper
::
afs_manager
=
nullptr
;
int
BoxWrapper
::
embedx_dim_
=
8
;
int
BoxWrapper
::
expand_embed_dim_
=
0
;
void
BasicAucCalculator
::
compute
()
{
double
*
table
[
2
]
=
{
&
_table
[
0
][
0
],
&
_table
[
1
][
0
]};
...
...
@@ -57,6 +59,94 @@ void BasicAucCalculator::compute() {
_size
=
fp
+
tp
;
}
void
BoxWrapper
::
CheckEmbedSizeIsValid
(
int
embedx_dim
,
int
expand_embed_dim
)
{
PADDLE_ENFORCE_EQ
(
embedx_dim_
,
embedx_dim
,
platform
::
errors
::
InvalidArgument
(
"SetInstance(): invalid embedx_dim. "
"When embedx_dim = %d, but got %d."
,
embedx_dim_
,
embedx_dim
));
PADDLE_ENFORCE_EQ
(
expand_embed_dim_
,
expand_embed_dim
,
platform
::
errors
::
InvalidArgument
(
"SetInstance(): invalid expand_embed_dim. When "
"expand_embed_dim = %d, but got %d."
,
expand_embed_dim_
,
expand_embed_dim
));
}
void
BoxWrapper
::
PullSparse
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
const
uint64_t
*>&
keys
,
const
std
::
vector
<
float
*>&
values
,
const
std
::
vector
<
int64_t
>&
slot_lengths
,
const
int
hidden_size
,
const
int
expand_embed_dim
)
{
#define EMBEDX_CASE(i, ...) \
case i: { \
constexpr size_t EmbedxDim = i; \
switch (expand_embed_dim) { \
__VA_ARGS__ \
default: \
PADDLE_THROW(platform::errors::InvalidArgument( \
"Unsupport this expand embedding size [%d]", expand_embed_dim)); \
} \
} break
#define PULLSPARSE_CASE(i, ...) \
case i: { \
constexpr size_t ExpandDim = i; \
PullSparseCase<EmbedxDim, ExpandDim>(place, keys, values, slot_lengths, \
hidden_size, expand_embed_dim); \
} break
CheckEmbedSizeIsValid
(
hidden_size
-
3
,
expand_embed_dim
);
switch
(
hidden_size
-
3
)
{
EMBEDX_CASE
(
8
,
PULLSPARSE_CASE
(
0
);
PULLSPARSE_CASE
(
8
);
PULLSPARSE_CASE
(
64
););
EMBEDX_CASE
(
16
,
PULLSPARSE_CASE
(
0
););
default:
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unsupport this embedding size [%d]"
,
hidden_size
-
3
));
}
#undef PULLSPARSE_CASE
#undef EMBEDX_CASE
}
void
BoxWrapper
::
PushSparseGrad
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
const
uint64_t
*>&
keys
,
const
std
::
vector
<
const
float
*>&
grad_values
,
const
std
::
vector
<
int64_t
>&
slot_lengths
,
const
int
hidden_size
,
const
int
expand_embed_dim
,
const
int
batch_size
)
{
#define EMBEDX_CASE(i, ...) \
case i: { \
constexpr size_t EmbedxDim = i; \
switch (expand_embed_dim) { \
__VA_ARGS__ \
default: \
PADDLE_THROW(platform::errors::InvalidArgument( \
"Unsupport this expand embedding size [%d]", expand_embed_dim)); \
} \
} break
#define PUSHSPARSE_CASE(i, ...) \
case i: { \
constexpr size_t ExpandDim = i; \
PushSparseGradCase<EmbedxDim, ExpandDim>(place, keys, grad_values, \
slot_lengths, hidden_size, \
expand_embed_dim, batch_size); \
} break
CheckEmbedSizeIsValid
(
hidden_size
-
3
,
expand_embed_dim
);
switch
(
hidden_size
-
3
)
{
EMBEDX_CASE
(
8
,
PUSHSPARSE_CASE
(
0
);
PUSHSPARSE_CASE
(
8
);
PUSHSPARSE_CASE
(
64
););
EMBEDX_CASE
(
16
,
PUSHSPARSE_CASE
(
0
););
default:
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unsupport this embedding size [%d]"
,
hidden_size
-
3
));
}
#undef PUSHSPARSE_CASE
#undef EMBEDX_CASE
}
void
BasicAucCalculator
::
calculate_bucket_error
()
{
double
last_ctr
=
-
1
;
double
impression_sum
=
0
;
...
...
@@ -128,134 +218,6 @@ void BoxWrapper::EndPass(bool need_save_delta) const {
ret
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"EndPass failed in BoxPS."
));
}
void
BoxWrapper
::
PullSparse
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
const
uint64_t
*>&
keys
,
const
std
::
vector
<
float
*>&
values
,
const
std
::
vector
<
int64_t
>&
slot_lengths
,
const
int
hidden_size
)
{
VLOG
(
3
)
<<
"Begin PullSparse"
;
platform
::
Timer
all_timer
;
platform
::
Timer
pull_boxps_timer
;
all_timer
.
Start
();
int64_t
total_length
=
std
::
accumulate
(
slot_lengths
.
begin
(),
slot_lengths
.
end
(),
0UL
);
auto
buf
=
memory
::
AllocShared
(
place
,
total_length
*
sizeof
(
boxps
::
FeatureValueGpu
));
boxps
::
FeatureValueGpu
*
total_values_gpu
=
reinterpret_cast
<
boxps
::
FeatureValueGpu
*>
(
buf
->
ptr
());
if
(
platform
::
is_cpu_place
(
place
))
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Warning:: CPUPlace is not supported in PaddleBox now."
));
}
else
if
(
platform
::
is_gpu_place
(
place
))
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
VLOG
(
3
)
<<
"Begin copy keys, key_num["
<<
total_length
<<
"]"
;
int
device_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place
).
GetDeviceId
();
LoDTensor
&
total_keys_tensor
=
keys_tensor
[
device_id
];
uint64_t
*
total_keys
=
reinterpret_cast
<
uint64_t
*>
(
total_keys_tensor
.
mutable_data
<
int64_t
>
({
total_length
,
1
},
place
));
// construct slot_level lod info
auto
slot_lengths_lod
=
slot_lengths
;
for
(
size_t
i
=
1
;
i
<
slot_lengths_lod
.
size
();
i
++
)
{
slot_lengths_lod
[
i
]
+=
slot_lengths_lod
[
i
-
1
];
}
auto
buf_key
=
memory
::
AllocShared
(
place
,
keys
.
size
()
*
sizeof
(
uint64_t
*
));
auto
buf_length
=
memory
::
AllocShared
(
place
,
slot_lengths
.
size
()
*
sizeof
(
int64_t
));
uint64_t
**
gpu_keys
=
reinterpret_cast
<
uint64_t
**>
(
buf_key
->
ptr
());
int64_t
*
gpu_len
=
reinterpret_cast
<
int64_t
*>
(
buf_length
->
ptr
());
cudaMemcpy
(
gpu_keys
,
keys
.
data
(),
keys
.
size
()
*
sizeof
(
uint64_t
*
),
cudaMemcpyHostToDevice
);
cudaMemcpy
(
gpu_len
,
slot_lengths_lod
.
data
(),
slot_lengths
.
size
()
*
sizeof
(
int64_t
),
cudaMemcpyHostToDevice
);
this
->
CopyKeys
(
place
,
gpu_keys
,
total_keys
,
gpu_len
,
static_cast
<
int
>
(
slot_lengths
.
size
()),
static_cast
<
int
>
(
total_length
));
VLOG
(
3
)
<<
"Begin call PullSparseGPU in BoxPS"
;
pull_boxps_timer
.
Start
();
int
ret
=
boxps_ptr_
->
PullSparseGPU
(
total_keys
,
total_values_gpu
,
static_cast
<
int
>
(
total_length
),
device_id
);
PADDLE_ENFORCE_EQ
(
ret
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"PullSparseGPU failed in BoxPS."
));
pull_boxps_timer
.
Pause
();
VLOG
(
3
)
<<
"Begin Copy result to tensor, total_length["
<<
total_length
<<
"]"
;
this
->
CopyForPull
(
place
,
gpu_keys
,
values
,
total_values_gpu
,
gpu_len
,
static_cast
<
int
>
(
slot_lengths
.
size
()),
hidden_size
,
total_length
);
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"Please compile WITH_GPU option, because NCCL doesn't support "
"windows."
));
#endif
}
else
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddleBox: PullSparse Only Support CPUPlace or CUDAPlace Now."
));
}
all_timer
.
Pause
();
VLOG
(
1
)
<<
"PullSparse total costs: "
<<
all_timer
.
ElapsedSec
()
<<
" s, of which BoxPS costs: "
<<
pull_boxps_timer
.
ElapsedSec
()
<<
" s"
;
VLOG
(
3
)
<<
"End PullSparse"
;
}
void
BoxWrapper
::
PushSparseGrad
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
const
uint64_t
*>&
keys
,
const
std
::
vector
<
const
float
*>&
grad_values
,
const
std
::
vector
<
int64_t
>&
slot_lengths
,
const
int
hidden_size
,
const
int
batch_size
)
{
VLOG
(
3
)
<<
"Begin PushSparseGrad"
;
platform
::
Timer
all_timer
;
platform
::
Timer
push_boxps_timer
;
all_timer
.
Start
();
int64_t
total_length
=
std
::
accumulate
(
slot_lengths
.
begin
(),
slot_lengths
.
end
(),
0UL
);
auto
buf
=
memory
::
AllocShared
(
place
,
total_length
*
sizeof
(
boxps
::
FeaturePushValueGpu
));
boxps
::
FeaturePushValueGpu
*
total_grad_values_gpu
=
reinterpret_cast
<
boxps
::
FeaturePushValueGpu
*>
(
buf
->
ptr
());
if
(
platform
::
is_cpu_place
(
place
))
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Warning:: CPUPlace is not supported in PaddleBox now."
));
}
else
if
(
platform
::
is_gpu_place
(
place
))
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
int
device_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place
).
GetDeviceId
();
LoDTensor
&
cached_total_keys_tensor
=
keys_tensor
[
device_id
];
uint64_t
*
total_keys
=
reinterpret_cast
<
uint64_t
*>
(
cached_total_keys_tensor
.
data
<
int64_t
>
());
VLOG
(
3
)
<<
"Begin copy grad tensor to boxps struct"
;
this
->
CopyForPush
(
place
,
grad_values
,
total_grad_values_gpu
,
slot_lengths
,
hidden_size
,
total_length
,
batch_size
);
VLOG
(
3
)
<<
"Begin call PushSparseGPU in BoxPS"
;
push_boxps_timer
.
Start
();
int
ret
=
boxps_ptr_
->
PushSparseGPU
(
total_keys
,
total_grad_values_gpu
,
static_cast
<
int
>
(
total_length
),
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place
).
GetDeviceId
());
PADDLE_ENFORCE_EQ
(
ret
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"PushSparseGPU failed in BoxPS."
));
push_boxps_timer
.
Pause
();
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"Please compile WITH_GPU option, because NCCL doesn't support "
"windows."
));
#endif
}
else
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddleBox: PushSparseGrad Only Support CPUPlace or CUDAPlace Now."
));
}
all_timer
.
Pause
();
VLOG
(
1
)
<<
"PushSparseGrad total cost: "
<<
all_timer
.
ElapsedSec
()
<<
" s, of which BoxPS cost: "
<<
push_boxps_timer
.
ElapsedSec
()
<<
" s"
;
VLOG
(
3
)
<<
"End PushSparseGrad"
;
}
void
BoxWrapper
::
GetRandomReplace
(
const
std
::
vector
<
Record
>&
pass_data
)
{
VLOG
(
0
)
<<
"Begin GetRandomReplace"
;
size_t
ins_num
=
pass_data
.
size
();
...
...
paddle/fluid/framework/fleet/box_wrapper.cu
浏览文件 @
95089204
...
...
@@ -27,9 +27,12 @@ namespace framework {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
__global__
void
PullCopy
(
float
**
dest
,
const
boxps
::
FeatureValueGpu
*
src
,
const
int64_t
*
len
,
int
hidden
,
int
slot_num
,
int
total_len
,
uint64_t
**
keys
)
{
template
<
size_t
EMBEDX_DIM
,
size_t
EXPAND_EMBED_DIM
>
__global__
void
PullCopy
(
float
**
dest
,
const
boxps
::
FeatureValueGpu
<
EMBEDX_DIM
,
EXPAND_EMBED_DIM
>*
src
,
const
int64_t
*
len
,
int
hidden
,
int
expand_dim
,
int
slot_num
,
int
total_len
,
uint64_t
**
keys
)
{
CUDA_KERNEL_LOOP
(
i
,
total_len
)
{
int
low
=
0
;
int
high
=
slot_num
-
1
;
...
...
@@ -52,15 +55,28 @@ __global__ void PullCopy(float** dest, const boxps::FeatureValueGpu* src,
*
(
dest
[
x
]
+
y
*
hidden
+
2
)
=
(
src
+
i
)
->
embed_w
;
}
if
((
src
+
i
)
->
embedding_size
==
0
||
*
(
keys
[
x
]
+
y
)
==
0
)
{
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
for
(
int
j
=
0
;
j
<
hidden
-
3
;
j
++
)
{
*
(
dest
[
x
]
+
y
*
hidden
+
3
+
j
)
=
0
;
}
}
else
{
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
for
(
int
j
=
0
;
j
<
hidden
-
3
;
j
++
)
{
*
(
dest
[
x
]
+
y
*
hidden
+
3
+
j
)
=
(
src
+
i
)
->
embedx
[
1
+
j
];
}
}
}
// process embed_expand
if
(
expand_dim
>
0
)
{
int
z
=
x
+
slot_num
;
if
((
src
+
i
)
->
embed_expand_size
[
0
]
==
0
||
*
(
keys
[
x
]
+
y
)
==
0
)
{
for
(
int
j
=
0
;
j
<
expand_dim
;
j
++
)
{
*
(
dest
[
z
]
+
y
*
expand_dim
+
j
)
=
0
;
}
}
else
{
for
(
int
j
=
0
;
j
<
expand_dim
;
j
++
)
{
*
(
dest
[
z
]
+
y
*
expand_dim
+
j
)
=
(
src
+
i
)
->
embed_expand
[
1
+
j
];
}
}
}
}
// end kernel loop
}
__global__
void
CopyKeysKernel
(
uint64_t
**
src_keys
,
uint64_t
*
dest_total_keys
,
...
...
@@ -82,9 +98,11 @@ __global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys,
}
}
__global__
void
PushCopy
(
boxps
::
FeaturePushValueGpu
*
dest
,
float
**
src
,
int64_t
*
len
,
int
hidden
,
int
slot_num
,
int
total_len
,
int
bs
,
int
*
slot_vector
)
{
template
<
size_t
EMBEDX_DIM
,
size_t
EXPAND_EMBED_DIM
>
__global__
void
PushCopy
(
boxps
::
FeaturePushValueGpu
<
EMBEDX_DIM
,
EXPAND_EMBED_DIM
>*
dest
,
float
**
src
,
int64_t
*
len
,
int
hidden
,
int
expand_dim
,
int
slot_num
,
int
total_len
,
int
bs
,
int
*
slot_vector
)
{
CUDA_KERNEL_LOOP
(
i
,
total_len
)
{
int
low
=
0
;
int
high
=
slot_num
-
1
;
...
...
@@ -101,18 +119,25 @@ __global__ void PushCopy(boxps::FeaturePushValueGpu* dest, float** src,
(
dest
+
i
)
->
show
=
*
(
src
[
x
]
+
y
*
hidden
);
(
dest
+
i
)
->
clk
=
*
(
src
[
x
]
+
y
*
hidden
+
1
);
(
dest
+
i
)
->
embed_g
=
*
(
src
[
x
]
+
y
*
hidden
+
2
)
*
-
1.
*
bs
;
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
for
(
int
j
=
0
;
j
<
hidden
-
3
;
j
++
)
{
(
dest
+
i
)
->
embedx_g
[
j
]
=
*
(
src
[
x
]
+
y
*
hidden
+
3
+
j
)
*
-
1.
*
bs
;
}
if
(
expand_dim
>
0
)
{
int
z
=
x
+
slot_num
;
for
(
int
j
=
0
;
j
<
expand_dim
;
j
++
)
{
(
dest
+
i
)
->
embed_expand_g
[
j
]
=
*
(
src
[
z
]
+
y
*
expand_dim
+
j
)
*
-
1.
*
bs
;
}
}
}
}
void
BoxWrapper
::
CopyForPull
(
const
paddle
::
platform
::
Place
&
place
,
uint64_t
**
gpu_keys
,
const
std
::
vector
<
float
*>&
values
,
const
boxps
::
FeatureValueGpu
*
total_values_gpu
,
const
int
64_t
*
gpu_len
,
const
int
slot_num
,
const
int
hidden_size
,
void
*
total_values_gpu
,
const
int64_t
*
gpu_len
,
const
int
slot_num
,
const
int
hidden_size
,
const
int
expand_embed_dim
,
const
int64_t
total_length
)
{
auto
stream
=
dynamic_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
...
...
@@ -122,11 +147,40 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place,
float
**
gpu_values
=
reinterpret_cast
<
float
**>
(
buf_value
->
ptr
());
cudaMemcpy
(
gpu_values
,
values
.
data
(),
values
.
size
()
*
sizeof
(
float
*
),
cudaMemcpyHostToDevice
);
#define EMBEDX_CASE(i, ...) \
case i: { \
constexpr size_t EmbedxDim = i; \
switch (expand_embed_dim) { \
__VA_ARGS__ \
default: \
PADDLE_THROW(platform::errors::InvalidArgument( \
"Unsupport this expand embedding size [%d]", expand_embed_dim)); \
} \
} break
#define EXPAND_EMBED_PULL_CASE(i, ...) \
case i: { \
constexpr size_t ExpandDim = i; \
PullCopy<EmbedxDim, \
ExpandDim><<<(total_length + 512 - 1) / 512, 512, 0, stream>>>( \
gpu_values, \
reinterpret_cast<boxps::FeatureValueGpu<EmbedxDim, ExpandDim>*>( \
total_values_gpu), \
gpu_len, hidden_size, expand_embed_dim, slot_num, total_length, \
gpu_keys); \
} break
PullCopy
<<<
(
total_length
+
512
-
1
)
/
512
,
512
,
0
,
stream
>>>
(
gpu_values
,
total_values_gpu
,
gpu_len
,
hidden_size
,
slot_num
,
total_length
,
gpu_keys
);
switch
(
hidden_size
-
3
)
{
EMBEDX_CASE
(
8
,
EXPAND_EMBED_PULL_CASE
(
0
);
EXPAND_EMBED_PULL_CASE
(
8
);
EXPAND_EMBED_PULL_CASE
(
64
););
EMBEDX_CASE
(
16
,
EXPAND_EMBED_PULL_CASE
(
0
););
default:
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unsupport this embedding size [%d]"
,
hidden_size
-
3
));
}
cudaStreamSynchronize
(
stream
);
#undef EXPAND_EMBED_PULL_CASE
#undef EMBEDX_CASE
}
void
BoxWrapper
::
CopyKeys
(
const
paddle
::
platform
::
Place
&
place
,
...
...
@@ -143,10 +197,10 @@ void BoxWrapper::CopyKeys(const paddle::platform::Place& place,
void
BoxWrapper
::
CopyForPush
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
const
float
*>&
grad_values
,
boxps
::
FeaturePushValueGpu
*
total_grad_values_gpu
,
void
*
total_grad_values_gpu
,
const
std
::
vector
<
int64_t
>&
slot_lengths
,
const
int
hidden_size
,
const
int
64_t
total_length
,
const
int
batch_size
)
{
const
int
hidden_size
,
const
int
expand_embed_dim
,
const
int
64_t
total_length
,
const
int
batch_size
)
{
auto
stream
=
dynamic_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place
)))
...
...
@@ -173,11 +227,42 @@ void BoxWrapper::CopyForPush(const paddle::platform::Place& place,
cudaMemcpy
(
d_slot_vector
,
slot_vector_
.
data
(),
slot_lengths_lod
.
size
()
*
sizeof
(
int
),
cudaMemcpyHostToDevice
);
PushCopy
<<<
(
total_length
+
512
-
1
)
/
512
,
512
,
0
,
stream
>>>
(
total_grad_values_gpu
,
gpu_values
,
gpu_len
,
hidden_size
,
slot_lengths
.
size
(),
total_length
,
batch_size
,
d_slot_vector
);
#define EMBEDX_CASE(i, ...) \
case i: { \
constexpr size_t EmbedxDim = i; \
switch (expand_embed_dim) { \
__VA_ARGS__ \
default: \
PADDLE_THROW(platform::errors::InvalidArgument( \
"Unsupport this expand embedding size [%d]", expand_embed_dim)); \
} \
} break
#define EXPAND_EMBED_PUSH_CASE(i, ...) \
case i: { \
constexpr size_t ExpandDim = i; \
PushCopy<EmbedxDim, \
ExpandDim><<<(total_length + 512 - 1) / 512, 512, 0, stream>>>( \
reinterpret_cast<boxps::FeaturePushValueGpu<EmbedxDim, ExpandDim>*>( \
total_grad_values_gpu), \
gpu_values, gpu_len, hidden_size, expand_embed_dim, \
slot_lengths.size(), total_length, batch_size, d_slot_vector); \
} break
switch
(
hidden_size
-
3
)
{
EMBEDX_CASE
(
8
,
EXPAND_EMBED_PUSH_CASE
(
0
);
EXPAND_EMBED_PUSH_CASE
(
8
);
EXPAND_EMBED_PUSH_CASE
(
64
););
EMBEDX_CASE
(
16
,
EXPAND_EMBED_PUSH_CASE
(
0
););
default:
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unsupport this embedding size [%d]"
,
hidden_size
-
3
));
}
cudaStreamSynchronize
(
stream
);
#undef EXPAND_EMBED_PUSH_CASE
#undef EMBEDX_CASE
}
}
// end namespace framework
}
// end namespace paddle
#endif
paddle/fluid/framework/fleet/box_wrapper.h
浏览文件 @
95089204
...
...
@@ -341,30 +341,54 @@ class BoxWrapper {
void
BeginPass
()
const
;
void
EndPass
(
bool
need_save_delta
)
const
;
void
SetTestMode
(
bool
is_test
)
const
;
template
<
size_t
EMBEDX_DIM
,
size_t
EXPAND_EMBED_DIM
=
0
>
void
PullSparseCase
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
const
uint64_t
*>&
keys
,
const
std
::
vector
<
float
*>&
values
,
const
std
::
vector
<
int64_t
>&
slot_lengths
,
const
int
hidden_size
,
const
int
expand_embed_dim
);
void
PullSparse
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
const
uint64_t
*>&
keys
,
const
std
::
vector
<
float
*>&
values
,
const
std
::
vector
<
int64_t
>&
slot_lengths
,
const
int
hidden_size
);
const
int
hidden_size
,
const
int
expand_embed_dim
);
template
<
size_t
EMBEDX_DIM
,
size_t
EXPAND_EMBED_DIM
=
0
>
void
PushSparseGradCase
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
const
uint64_t
*>&
keys
,
const
std
::
vector
<
const
float
*>&
grad_values
,
const
std
::
vector
<
int64_t
>&
slot_lengths
,
const
int
hidden_size
,
const
int
expand_embed_dim
,
const
int
batch_size
);
void
PushSparseGrad
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
const
uint64_t
*>&
keys
,
const
std
::
vector
<
const
float
*>&
grad_values
,
const
std
::
vector
<
int64_t
>&
slot_lengths
,
const
int
hidden_size
,
const
int
batch_size
);
const
int
hidden_size
,
const
int
expand_embed_dim
,
const
int
batch_size
);
void
CopyForPull
(
const
paddle
::
platform
::
Place
&
place
,
uint64_t
**
gpu_keys
,
const
std
::
vector
<
float
*>&
values
,
const
boxps
::
FeatureValueGpu
*
total_values_gpu
,
const
std
::
vector
<
float
*>&
values
,
void
*
total_values_gpu
,
const
int64_t
*
gpu_len
,
const
int
slot_num
,
const
int
hidden_size
,
const
int64_t
total_length
);
const
int
hidden_size
,
const
int
expand_embed_dim
,
const
int64_t
total_length
);
void
CopyForPush
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
const
float
*>&
grad_values
,
boxps
::
FeaturePushValueGpu
*
total_grad_values_gpu
,
void
*
total_grad_values_gpu
,
const
std
::
vector
<
int64_t
>&
slot_lengths
,
const
int
hidden_size
,
const
int64_t
total_length
,
const
int
batch_size
);
const
int
hidden_size
,
const
int
expand_embed_dim
,
const
int64_t
total_length
,
const
int
batch_size
);
void
CopyKeys
(
const
paddle
::
platform
::
Place
&
place
,
uint64_t
**
origin_keys
,
uint64_t
*
total_keys
,
const
int64_t
*
gpu_len
,
int
slot_num
,
int
total_len
);
void
CheckEmbedSizeIsValid
(
int
embedx_dim
,
int
expand_embed_dim
);
boxps
::
PSAgentBase
*
GetAgent
()
{
return
p_agent_
;
}
void
InitializeGPUAndLoadModel
(
const
char
*
conf_file
,
const
std
::
vector
<
int
>&
slot_vector
,
...
...
@@ -442,6 +466,15 @@ class BoxWrapper {
}
static
std
::
shared_ptr
<
BoxWrapper
>
GetInstance
()
{
PADDLE_ENFORCE_EQ
(
s_instance_
==
nullptr
,
false
,
platform
::
errors
::
PreconditionNotMet
(
"GetInstance failed in BoxPs, you should use SetInstance firstly"
));
return
s_instance_
;
}
static
std
::
shared_ptr
<
BoxWrapper
>
SetInstance
(
int
embedx_dim
=
8
,
int
expand_embed_dim
=
0
)
{
if
(
nullptr
==
s_instance_
)
{
// If main thread is guaranteed to init this, this lock can be removed
static
std
::
mutex
mutex
;
...
...
@@ -449,8 +482,13 @@ class BoxWrapper {
if
(
nullptr
==
s_instance_
)
{
VLOG
(
3
)
<<
"s_instance_ is null"
;
s_instance_
.
reset
(
new
paddle
::
framework
::
BoxWrapper
());
s_instance_
->
boxps_ptr_
.
reset
(
boxps
::
BoxPSBase
::
GetIns
());
s_instance_
->
boxps_ptr_
.
reset
(
boxps
::
BoxPSBase
::
GetIns
(
embedx_dim
,
expand_embed_dim
));
embedx_dim_
=
embedx_dim
;
expand_embed_dim_
=
expand_embed_dim
;
}
}
else
{
LOG
(
WARNING
)
<<
"You have already used SetInstance() before"
;
}
return
s_instance_
;
}
...
...
@@ -776,6 +814,9 @@ class BoxWrapper {
const
int
feedpass_thread_num_
=
30
;
// magic number
static
std
::
shared_ptr
<
BoxWrapper
>
s_instance_
;
std
::
unordered_set
<
std
::
string
>
slot_name_omited_in_feedpass_
;
// EMBEDX_DIM and EXPAND_EMBED_DIM
static
int
embedx_dim_
;
static
int
expand_embed_dim_
;
// Metric Related
int
phase_
=
1
;
...
...
@@ -1009,3 +1050,5 @@ class BoxHelper {
}
// end namespace framework
}
// end namespace paddle
#include "paddle/fluid/framework/fleet/box_wrapper_impl.h"
paddle/fluid/framework/fleet/box_wrapper_impl.h
0 → 100644
浏览文件 @
95089204
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_BOX_PS
#include <vector>
namespace
paddle
{
namespace
framework
{
template
<
size_t
EMBEDX_DIM
,
size_t
EXPAND_EMBED_DIM
>
void
BoxWrapper
::
PullSparseCase
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
const
uint64_t
*>&
keys
,
const
std
::
vector
<
float
*>&
values
,
const
std
::
vector
<
int64_t
>&
slot_lengths
,
const
int
hidden_size
,
const
int
expand_embed_dim
)
{
VLOG
(
3
)
<<
"Begin PullSparse"
;
platform
::
Timer
all_timer
;
platform
::
Timer
pull_boxps_timer
;
all_timer
.
Start
();
int64_t
total_length
=
std
::
accumulate
(
slot_lengths
.
begin
(),
slot_lengths
.
end
(),
0UL
);
auto
buf
=
memory
::
AllocShared
(
place
,
total_length
*
sizeof
(
boxps
::
FeatureValueGpu
<
EMBEDX_DIM
,
EXPAND_EMBED_DIM
>
));
boxps
::
FeatureValueGpu
<
EMBEDX_DIM
,
EXPAND_EMBED_DIM
>*
total_values_gpu
=
reinterpret_cast
<
boxps
::
FeatureValueGpu
<
EMBEDX_DIM
,
EXPAND_EMBED_DIM
>*>
(
buf
->
ptr
());
if
(
platform
::
is_cpu_place
(
place
))
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Warning:: CPUPlace is not supported in PaddleBox now."
));
}
else
if
(
platform
::
is_gpu_place
(
place
))
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
VLOG
(
3
)
<<
"Begin copy keys, key_num["
<<
total_length
<<
"]"
;
int
device_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place
).
GetDeviceId
();
LoDTensor
&
total_keys_tensor
=
keys_tensor
[
device_id
];
uint64_t
*
total_keys
=
reinterpret_cast
<
uint64_t
*>
(
total_keys_tensor
.
mutable_data
<
int64_t
>
({
total_length
,
1
},
place
));
// construct slot_level lod info
auto
slot_lengths_lod
=
slot_lengths
;
for
(
size_t
i
=
1
;
i
<
slot_lengths_lod
.
size
();
i
++
)
{
slot_lengths_lod
[
i
]
+=
slot_lengths_lod
[
i
-
1
];
}
auto
buf_key
=
memory
::
AllocShared
(
place
,
keys
.
size
()
*
sizeof
(
uint64_t
*
));
auto
buf_length
=
memory
::
AllocShared
(
place
,
slot_lengths
.
size
()
*
sizeof
(
int64_t
));
uint64_t
**
gpu_keys
=
reinterpret_cast
<
uint64_t
**>
(
buf_key
->
ptr
());
int64_t
*
gpu_len
=
reinterpret_cast
<
int64_t
*>
(
buf_length
->
ptr
());
cudaMemcpy
(
gpu_keys
,
keys
.
data
(),
keys
.
size
()
*
sizeof
(
uint64_t
*
),
cudaMemcpyHostToDevice
);
cudaMemcpy
(
gpu_len
,
slot_lengths_lod
.
data
(),
slot_lengths
.
size
()
*
sizeof
(
int64_t
),
cudaMemcpyHostToDevice
);
this
->
CopyKeys
(
place
,
gpu_keys
,
total_keys
,
gpu_len
,
static_cast
<
int
>
(
slot_lengths
.
size
()),
static_cast
<
int
>
(
total_length
));
VLOG
(
3
)
<<
"Begin call PullSparseGPU in BoxPS"
;
pull_boxps_timer
.
Start
();
int
ret
=
boxps_ptr_
->
PullSparseGPU
(
total_keys
,
reinterpret_cast
<
void
*>
(
total_values_gpu
),
static_cast
<
int
>
(
total_length
),
device_id
);
PADDLE_ENFORCE_EQ
(
ret
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"PullSparseGPU failed in BoxPS."
));
pull_boxps_timer
.
Pause
();
VLOG
(
3
)
<<
"Begin Copy result to tensor, total_length["
<<
total_length
<<
"]"
;
this
->
CopyForPull
(
place
,
gpu_keys
,
values
,
reinterpret_cast
<
void
*>
(
total_values_gpu
),
gpu_len
,
static_cast
<
int
>
(
slot_lengths
.
size
()),
hidden_size
,
expand_embed_dim
,
total_length
);
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"Please compile WITH_GPU option, because NCCL doesn't support "
"windows."
));
#endif
}
else
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddleBox: PullSparse Only Support CPUPlace or CUDAPlace Now."
));
}
all_timer
.
Pause
();
VLOG
(
1
)
<<
"PullSparse total costs: "
<<
all_timer
.
ElapsedSec
()
<<
" s, of which BoxPS costs: "
<<
pull_boxps_timer
.
ElapsedSec
()
<<
" s"
;
VLOG
(
3
)
<<
"End PullSparse"
;
}
template
<
size_t
EMBEDX_DIM
,
size_t
EXPAND_EMBED_DIM
>
void
BoxWrapper
::
PushSparseGradCase
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
const
uint64_t
*>&
keys
,
const
std
::
vector
<
const
float
*>&
grad_values
,
const
std
::
vector
<
int64_t
>&
slot_lengths
,
const
int
hidden_size
,
const
int
expand_embed_dim
,
const
int
batch_size
)
{
VLOG
(
3
)
<<
"Begin PushSparseGrad"
;
platform
::
Timer
all_timer
;
platform
::
Timer
push_boxps_timer
;
all_timer
.
Start
();
int64_t
total_length
=
std
::
accumulate
(
slot_lengths
.
begin
(),
slot_lengths
.
end
(),
0UL
);
auto
buf
=
memory
::
AllocShared
(
place
,
total_length
*
sizeof
(
boxps
::
FeaturePushValueGpu
<
EMBEDX_DIM
,
EXPAND_EMBED_DIM
>
));
boxps
::
FeaturePushValueGpu
<
EMBEDX_DIM
,
EXPAND_EMBED_DIM
>*
total_grad_values_gpu
=
reinterpret_cast
<
boxps
::
FeaturePushValueGpu
<
EMBEDX_DIM
,
EXPAND_EMBED_DIM
>*>
(
buf
->
ptr
());
if
(
platform
::
is_cpu_place
(
place
))
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Warning:: CPUPlace is not supported in PaddleBox now."
));
}
else
if
(
platform
::
is_gpu_place
(
place
))
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
int
device_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place
).
GetDeviceId
();
LoDTensor
&
cached_total_keys_tensor
=
keys_tensor
[
device_id
];
uint64_t
*
total_keys
=
reinterpret_cast
<
uint64_t
*>
(
cached_total_keys_tensor
.
data
<
int64_t
>
());
VLOG
(
3
)
<<
"Begin copy grad tensor to boxps struct"
;
this
->
CopyForPush
(
place
,
grad_values
,
total_grad_values_gpu
,
slot_lengths
,
hidden_size
,
expand_embed_dim
,
total_length
,
batch_size
);
VLOG
(
3
)
<<
"Begin call PushSparseGPU in BoxPS"
;
push_boxps_timer
.
Start
();
int
ret
=
boxps_ptr_
->
PushSparseGPU
(
total_keys
,
reinterpret_cast
<
void
*>
(
total_grad_values_gpu
),
static_cast
<
int
>
(
total_length
),
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place
).
GetDeviceId
());
PADDLE_ENFORCE_EQ
(
ret
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"PushSparseGPU failed in BoxPS."
));
push_boxps_timer
.
Pause
();
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"Please compile WITH_GPU option, because NCCL doesn't support "
"windows."
));
#endif
}
else
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddleBox: PushSparseGrad Only Support CPUPlace or CUDAPlace Now."
));
}
all_timer
.
Pause
();
VLOG
(
1
)
<<
"PushSparseGrad total cost: "
<<
all_timer
.
ElapsedSec
()
<<
" s, of which BoxPS cost: "
<<
push_boxps_timer
.
ElapsedSec
()
<<
" s"
;
VLOG
(
3
)
<<
"End PushSparseGrad"
;
}
}
// namespace framework
}
// namespace paddle
#endif
paddle/fluid/operators/pull_box_extended_sparse_op.cc
0 → 100644
浏览文件 @
95089204
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/pull_box_extended_sparse_op.h"
namespace
paddle
{
namespace
operators
{
class
PullBoxExtendedSparseOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_GE
(
ctx
->
Inputs
(
"Ids"
).
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"Inputs(Ids) of PullBoxExtendedSparseOp should not be empty."
));
PADDLE_ENFORCE_GE
(
ctx
->
Outputs
(
"Out"
).
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"Outputs(Out) of PullBoxExtendedSparseOp should not be empty."
));
PADDLE_ENFORCE_GE
(
ctx
->
Outputs
(
"OutExtend"
).
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"Outputs(OutExtend) of PullBoxExtendedSparseOp "
"should not be empty."
));
auto
emb_size
=
static_cast
<
int64_t
>
(
ctx
->
Attrs
().
Get
<
int
>
(
"emb_size"
));
auto
emb_extended_size
=
static_cast
<
int64_t
>
(
ctx
->
Attrs
().
Get
<
int
>
(
"emb_extended_size"
));
auto
all_ids_dim
=
ctx
->
GetInputsDim
(
"Ids"
);
const
size_t
n_ids
=
all_ids_dim
.
size
();
std
::
vector
<
framework
::
DDim
>
outs_dims
;
std
::
vector
<
framework
::
DDim
>
outs_extended_dims
;
outs_dims
.
resize
(
n_ids
);
outs_extended_dims
.
resize
(
n_ids
);
for
(
size_t
i
=
0
;
i
<
n_ids
;
++
i
)
{
const
auto
ids_dims
=
all_ids_dim
[
i
];
int
ids_rank
=
ids_dims
.
size
();
PADDLE_ENFORCE_EQ
(
ids_dims
[
ids_rank
-
1
],
1
,
platform
::
errors
::
InvalidArgument
(
"Shape error in %lu id, the last dimension of the "
"'Ids' tensor must be 1."
,
i
));
auto
out_dim
=
framework
::
vectorize
(
framework
::
slice_ddim
(
ids_dims
,
0
,
ids_rank
-
1
));
out_dim
.
push_back
(
emb_size
);
outs_dims
[
i
]
=
framework
::
make_ddim
(
out_dim
);
auto
out_extended_dim
=
framework
::
vectorize
(
framework
::
slice_ddim
(
ids_dims
,
0
,
ids_rank
-
1
));
out_extended_dim
.
push_back
(
emb_extended_size
);
outs_extended_dims
[
i
]
=
framework
::
make_ddim
(
out_extended_dim
);
}
ctx
->
SetOutputsDim
(
"Out"
,
outs_dims
);
ctx
->
SetOutputsDim
(
"OutExtend"
,
outs_extended_dims
);
for
(
size_t
i
=
0
;
i
<
n_ids
;
++
i
)
{
ctx
->
ShareLoD
(
"Ids"
,
"Out"
,
i
,
i
);
ctx
->
ShareLoD
(
"Ids"
,
"OutExtend"
,
i
,
i
);
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
proto
::
VarType
::
FP32
,
ctx
.
device_context
());
}
};
class
PullBoxExtendedSparseOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Ids"
,
"Input tensors with type int32 or int64 "
"contains the ids to be looked up in BoxPS. "
"The last dimension size must be 1."
)
.
AsDuplicable
();
AddOutput
(
"Out"
,
"The lookup results tensors."
).
AsDuplicable
();
AddOutput
(
"OutExtend"
,
"The lookup extended results tensors."
)
.
AsDuplicable
();
AddAttr
<
int
>
(
"emb_size"
,
"(int, the embedding hidden size"
).
SetDefault
(
1
);
AddAttr
<
int
>
(
"emb_extended_size"
,
"(int, the extended_embedding hidden size"
)
.
SetDefault
(
128
);
AddComment
(
R"DOC(
Pull Box Extended Sparse Operator.
This operator is used to perform lookups on the BoxPS,
then concatenated into a dense tensor.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC"
);
}
};
template
<
typename
T
>
class
PushBoxExtendedSparseOpMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
op
)
const
override
{
op
->
SetType
(
"push_box_extended_sparse"
);
op
->
SetInput
(
"Ids"
,
this
->
Input
(
"Ids"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
op
->
SetInput
(
framework
::
GradVarName
(
"OutExtend"
),
this
->
OutputGrad
(
"OutExtend"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
op
->
SetAttrMap
(
this
->
Attrs
());
}
};
class
PushBoxExtendedSparseOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
framework
::
GradVarName
(
"Out"
)),
ctx
.
device_context
());
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
pull_box_extended_sparse
,
ops
::
PullBoxExtendedSparseOp
,
ops
::
PullBoxExtendedSparseOpMaker
,
ops
::
PushBoxExtendedSparseOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
PushBoxExtendedSparseOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
push_box_extended_sparse
,
ops
::
PushBoxExtendedSparseOp
);
REGISTER_OP_CPU_KERNEL
(
pull_box_extended_sparse
,
ops
::
PullBoxExtendedSparseCPUKernel
<
float
>
,
ops
::
PullBoxExtendedSparseCPUKernel
<
double
>
);
REGISTER_OP_CPU_KERNEL
(
push_box_extended_sparse
,
ops
::
PushBoxExtendedSparseCPUKernel
<
float
>
,
ops
::
PushBoxExtendedSparseCPUKernel
<
double
>
);
paddle/fluid/operators/pull_box_extended_sparse_op.cu
0 → 100644
浏览文件 @
95089204
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/pull_box_extended_sparse_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
PullBoxExtendedSparseCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PullBoxExtendedSparseFunctor
<
T
>
(
ctx
);
}
};
template
<
typename
T
>
class
PushBoxExtendedSparseCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PushBoxExtendedSparseFunctor
<
T
>
(
ctx
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
pull_box_extended_sparse
,
ops
::
PullBoxExtendedSparseCUDAKernel
<
float
>
,
ops
::
PullBoxExtendedSparseCUDAKernel
<
double
>
);
REGISTER_OP_CUDA_KERNEL
(
push_box_extended_sparse
,
ops
::
PushBoxExtendedSparseCUDAKernel
<
float
>
,
ops
::
PushBoxExtendedSparseCUDAKernel
<
double
>
);
paddle/fluid/operators/pull_box_extended_sparse_op.h
0 → 100644
浏览文件 @
95089204
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
static
void
PullBoxExtendedSparseFunctor
(
const
framework
::
ExecutionContext
&
ctx
)
{
auto
inputs
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"Ids"
);
auto
outputs
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
"Out"
);
auto
outputs_extend
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
"OutExtend"
);
const
auto
slot_size
=
inputs
.
size
();
std
::
vector
<
const
uint64_t
*>
all_keys
(
slot_size
);
// BoxPS only supports float now
std
::
vector
<
float
*>
all_values
(
slot_size
*
2
);
std
::
vector
<
int64_t
>
slot_lengths
(
slot_size
);
for
(
size_t
i
=
0
;
i
<
slot_size
;
i
++
)
{
const
auto
*
slot
=
inputs
[
i
];
const
uint64_t
*
single_slot_keys
=
reinterpret_cast
<
const
uint64_t
*>
(
slot
->
data
<
int64_t
>
());
all_keys
[
i
]
=
single_slot_keys
;
slot_lengths
[
i
]
=
slot
->
numel
();
auto
*
output
=
outputs
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
output_extend
=
outputs_extend
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
all_values
[
i
]
=
reinterpret_cast
<
float
*>
(
output
);
all_values
[
i
+
slot_size
]
=
reinterpret_cast
<
float
*>
(
output_extend
);
}
#ifdef PADDLE_WITH_BOX_PS
auto
emb_size
=
ctx
.
Attr
<
int
>
(
"emb_size"
);
auto
emb_extended_size
=
ctx
.
Attr
<
int
>
(
"emb_extended_size"
);
auto
box_ptr
=
paddle
::
framework
::
BoxWrapper
::
GetInstance
();
box_ptr
->
PullSparse
(
ctx
.
GetPlace
(),
all_keys
,
all_values
,
slot_lengths
,
emb_size
,
emb_extended_size
);
#endif
}
template
<
typename
T
>
static
void
PushBoxExtendedSparseFunctor
(
const
framework
::
ExecutionContext
&
ctx
)
{
auto
inputs
=
ctx
.
MultiInput
<
framework
::
LoDTensor
>
(
"Ids"
);
auto
d_output
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
d_output_extend
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"OutExtend"
));
const
auto
slot_size
=
inputs
.
size
();
std
::
vector
<
const
uint64_t
*>
all_keys
(
slot_size
);
std
::
vector
<
const
float
*>
all_grad_values
(
slot_size
*
2
);
std
::
vector
<
int64_t
>
slot_lengths
(
slot_size
);
int
batch_size
=
-
1
;
for
(
size_t
i
=
0
;
i
<
slot_size
;
i
++
)
{
const
auto
*
slot
=
inputs
[
i
];
const
uint64_t
*
single_slot_keys
=
reinterpret_cast
<
const
uint64_t
*>
(
slot
->
data
<
int64_t
>
());
all_keys
[
i
]
=
single_slot_keys
;
slot_lengths
[
i
]
=
slot
->
numel
();
int
cur_batch_size
=
slot
->
lod
().
size
()
?
slot
->
lod
()[
0
].
size
()
-
1
:
slot
->
dims
()[
0
];
if
(
batch_size
==
-
1
)
{
batch_size
=
cur_batch_size
;
}
else
{
PADDLE_ENFORCE_EQ
(
batch_size
,
cur_batch_size
,
platform
::
errors
::
PreconditionNotMet
(
"The batch size of all input slots should be same,"
"please cheack"
));
}
const
float
*
grad_value
=
d_output
[
i
]
->
data
<
float
>
();
const
float
*
grad_value_extend
=
d_output_extend
[
i
]
->
data
<
float
>
();
all_grad_values
[
i
]
=
reinterpret_cast
<
const
float
*>
(
grad_value
);
all_grad_values
[
i
+
slot_size
]
=
reinterpret_cast
<
const
float
*>
(
grad_value_extend
);
}
#ifdef PADDLE_WITH_BOX_PS
auto
emb_size
=
ctx
.
Attr
<
int
>
(
"emb_size"
);
auto
emb_extended_size
=
ctx
.
Attr
<
int
>
(
"emb_extended_size"
);
auto
box_ptr
=
paddle
::
framework
::
BoxWrapper
::
GetInstance
();
box_ptr
->
PushSparseGrad
(
ctx
.
GetPlace
(),
all_keys
,
all_grad_values
,
slot_lengths
,
emb_size
,
emb_extended_size
,
batch_size
);
#endif
}
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
T
>
class
PullBoxExtendedSparseCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PullBoxExtendedSparseFunctor
<
T
>
(
ctx
);
}
};
template
<
typename
T
>
class
PushBoxExtendedSparseCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PushBoxExtendedSparseFunctor
<
T
>
(
ctx
);
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/pull_box_sparse_op.h
浏览文件 @
95089204
...
...
@@ -44,7 +44,7 @@ static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) {
auto
hidden_size
=
ctx
.
Attr
<
int
>
(
"size"
);
auto
box_ptr
=
paddle
::
framework
::
BoxWrapper
::
GetInstance
();
box_ptr
->
PullSparse
(
ctx
.
GetPlace
(),
all_keys
,
all_values
,
slot_lengths
,
hidden_size
);
hidden_size
,
0
);
#endif
}
...
...
@@ -81,7 +81,7 @@ static void PushBoxSparseFunctor(const framework::ExecutionContext &ctx) {
auto
hidden_size
=
ctx
.
Attr
<
int
>
(
"size"
);
auto
box_ptr
=
paddle
::
framework
::
BoxWrapper
::
GetInstance
();
box_ptr
->
PushSparseGrad
(
ctx
.
GetPlace
(),
all_keys
,
all_grad_values
,
slot_lengths
,
hidden_size
,
batch_size
);
slot_lengths
,
hidden_size
,
0
,
batch_size
);
#endif
}
...
...
paddle/fluid/pybind/box_helper_py.cc
浏览文件 @
95089204
...
...
@@ -63,9 +63,9 @@ void BindBoxHelper(py::module* m) {
void
BindBoxWrapper
(
py
::
module
*
m
)
{
py
::
class_
<
framework
::
BoxWrapper
,
std
::
shared_ptr
<
framework
::
BoxWrapper
>>
(
*
m
,
"BoxWrapper"
)
.
def
(
py
::
init
([]()
{
.
def
(
py
::
init
([](
int
embedx_dim
,
int
expand_embed_dim
)
{
// return std::make_shared<paddle::framework::BoxHelper>(dataset);
return
framework
::
BoxWrapper
::
GetInstance
(
);
return
framework
::
BoxWrapper
::
SetInstance
(
embedx_dim
,
expand_embed_dim
);
}))
.
def
(
"save_base"
,
&
framework
::
BoxWrapper
::
SaveBase
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
...
...
python/paddle/fluid/contrib/layers/nn.py
浏览文件 @
95089204
...
...
@@ -34,7 +34,8 @@ __all__ = [
'fused_elemwise_activation'
,
'sequence_topk_avg_pooling'
,
'var_conv_2d'
,
'match_matrix_tensor'
,
'tree_conv'
,
'fused_embedding_seq_pool'
,
'multiclass_nms2'
,
'search_pyramid_hash'
,
'shuffle_batch'
,
'partial_concat'
,
'partial_sum'
,
'tdm_child'
,
'rank_attention'
,
'tdm_sampler'
,
'batch_fc'
'partial_sum'
,
'tdm_child'
,
'rank_attention'
,
'tdm_sampler'
,
'batch_fc'
,
'_pull_box_extended_sparse'
]
...
...
@@ -1361,3 +1362,50 @@ def batch_fc(input, param_size, param_attr, bias_size, bias_attr, act=None):
"Bias"
:
b
},
outputs
=
{
"Out"
:
pre_act
})
return
helper
.
append_activation
(
pre_act
)
def
_pull_box_extended_sparse
(
input
,
size
,
extend_size
=
64
,
dtype
=
'float32'
):
"""
**Pull Box Extended Sparse Layer**
This layer is used to lookup embeddings of IDs, provided by :attr:`input`, in
BoxPS lookup table. The result of this lookup is the embedding of each ID in the
:attr:`input`.
Args:
input(Variable|list of Variable): Input is a Tensor<int64> Variable, which
contains the IDs information.
size(int): The embedding size parameter, which indicates the size of
each embedding vector respectively.
extend_size(int): The embedding size parameter in extended dim,
which indicates the size of each embedding vector respectively.
dtype(str): The dtype refers to the data type of output tensor. Only supports
float32 now.
Returns:
Variable|list of Variable: The tensor variable storing the embeddings of the
\
supplied inputs.
Examples:
.. code-block:: python
import paddle.fluid as fluid
data = fluid.layers.data(name='sequence', shape=[1], dtype='int64', lod_level=1)
emb, emb_ex = fluid.contrib.layers._pull_box_extended_sparse(input=data, size=8, extend_size=128)
"""
helper
=
LayerHelper
(
'pull_box_extended_sparse'
,
**
locals
())
helper
.
input_dtype
()
inputs
=
helper
.
multiple_input
()
outs
=
[
helper
.
create_variable_for_type_inference
(
dtype
)
for
i
in
range
(
len
(
inputs
))
]
outs_extend
=
[
helper
.
create_variable_for_type_inference
(
dtype
)
for
i
in
range
(
len
(
inputs
))
]
helper
.
append_op
(
type
=
'pull_box_extended_sparse'
,
inputs
=
{
'Ids'
:
inputs
},
outputs
=
{
'Out'
:
outs
,
'OutExtend'
:
outs_extend
},
attrs
=
{
'emb_size'
:
size
,
'emb_extended_size'
:
extend_size
})
if
len
(
outs
)
==
1
:
return
outs
[
0
],
outs_extend
[
0
]
return
outs
,
outs_extend
python/paddle/fluid/tests/unittests/test_paddlebox_datafeed.py
浏览文件 @
95089204
...
...
@@ -17,7 +17,6 @@ import paddle.fluid.core as core
import
os
import
unittest
import
paddle.fluid.layers
as
layers
from
paddle.fluid.layers.nn
import
_pull_box_sparse
class
TestDataFeed
(
unittest
.
TestCase
):
...
...
@@ -57,9 +56,9 @@ class TestDataFeed(unittest.TestCase):
lod_level
=
0
,
append_batch_size
=
False
)
emb_x
,
emb_y
=
_pull_box_sparse
([
x
,
y
],
size
=
2
)
emb_xp
=
_pull_box_sparse
(
x
,
size
=
2
)
concat
=
layers
.
concat
([
emb_x
,
emb_y
],
axis
=
1
)
emb_x
,
emb_y
=
fluid
.
contrib
.
layers
.
_pull_box_extended_sparse
(
[
x
,
y
],
size
=
2
,
extend_size
=
128
)
concat
=
layers
.
concat
([
emb_x
[
0
],
emb_x
[
1
],
emb_y
[
0
],
emb_y
[
1
]
],
axis
=
1
)
fc
=
layers
.
fc
(
input
=
concat
,
name
=
"fc"
,
size
=
1
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录