Paddle commit 802f362a
Authored Mar 07, 2019 by tensor-tang
Parent: 36e2d324

unify the kernelfuncs cache and add unit test
test=develop

Showing 15 changed files with 158 additions and 106 deletions
paddle/fluid/operators/crf_decoding_op.h  (+3/-2)
paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc  (+4/-2)
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h  (+8/-4)
paddle/fluid/operators/fused/fusion_gru_op.cc  (+26/-23)
paddle/fluid/operators/fused/fusion_lstm_op.cc  (+28/-26)
paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc  (+6/-4)
paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc  (+3/-3)
paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc  (+18/-14)
paddle/fluid/operators/jit/CMakeLists.txt  (+1/-1)
paddle/fluid/operators/jit/benchmark.cc  (+1/-1)
paddle/fluid/operators/jit/helper.h  (+25/-9)
paddle/fluid/operators/jit/test.cc  (+23/-7)
paddle/fluid/operators/layer_norm_op.h  (+3/-3)
paddle/fluid/operators/math/sequence_pooling.cc  (+3/-3)
paddle/fluid/operators/optimizers/sgd_op.h  (+6/-4)
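Every call-site change below follows the same mechanical rewrite: the per-call lookup jit::Get<KernelType, Tuples, Place>(attr) becomes jit::KernelFuncs<KernelType, Tuples, Place>::Cache().At(attr), so the kernel chosen for a given attr is resolved once and then served from a cache keyed on the attr. What follows is a minimal self-contained sketch of that pattern, not Paddle's implementation; FuncCache, Resolver, and the use of std::hash are stand-ins invented for this sketch (Paddle keys its cache with XXH64, see helper.h below).

// Minimal sketch of an attr-keyed kernel-function cache (illustration only;
// FuncCache/Resolver are not Paddle's API).
#include <cstdint>
#include <functional>
#include <iostream>
#include <unordered_map>

template <typename Attr, typename Func>
class FuncCache {
 public:
  using Resolver = std::function<Func(const Attr&)>;  // stand-in for jit::Get

  explicit FuncCache(Resolver resolver) : resolver_(std::move(resolver)) {}

  // Resolve the kernel for `attr` once; later lookups with an equal attr are
  // served from the map. This is what Cache().At(attr) does in this commit.
  Func At(const Attr& attr) {
    int64_t key = static_cast<int64_t>(std::hash<Attr>{}(attr));  // Paddle uses XXH64
    auto it = funcs_.find(key);
    if (it != funcs_.end()) return it->second;
    Func f = resolver_(attr);
    funcs_.emplace(key, f);
    return f;
  }

  Func operator[](const Attr& attr) { return At(attr); }  // mirrors the new operator[]

 private:
  Resolver resolver_;
  std::unordered_map<int64_t, Func> funcs_;
};

int main() {
  using VAddFunc = void (*)(const float*, const float*, float*, int);
  // The resolver plays the role of jit::Get: pick the best kernel for an attr.
  FuncCache<int, VAddFunc> cache([](const int& d) -> VAddFunc {
    std::cout << "resolving kernel for d = " << d << "\n";  // printed once per attr
    return [](const float* x, const float* y, float* z, int n) {
      for (int i = 0; i < n; ++i) z[i] = x[i] + y[i];
    };
  });
  float a[3] = {1, 2, 3}, b[3] = {4, 5, 6}, c[3];
  cache.At(3)(a, b, c, 3);  // miss: resolver runs
  cache.At(3)(a, b, c, 3);  // hit: served from the cache
  std::cout << c[0] << " " << c[1] << " " << c[2] << "\n";  // prints 5 7 9
  return 0;
}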
paddle/fluid/operators/crf_decoding_op.h
@@ -82,8 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
     Tensor track;
     int* track_value =
         track.mutable_data<int>(emission_dims, platform::CPUPlace());
-    auto ker = jit::Get<jit::kCRFDecoding, jit::CRFDecodingTuples<T>,
-                        platform::CPUPlace>(tag_num);
+    auto ker = jit::KernelFuncs<jit::kCRFDecoding, jit::CRFDecodingTuples<T>,
+                                platform::CPUPlace>::Cache()
+                   .At(tag_num);
     ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
     T max_score = -std::numeric_limits<T>::max();
     int max_i = 0;
paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc
@@ -110,8 +110,10 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
       constexpr int simd_width = 16;
       int C = c / simd_width;
-      auto multiply = jit::Get<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>,
-                               platform::CPUPlace>(0);
+      auto multiply =
+          jit::KernelFuncs<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>,
+                           platform::CPUPlace>::Cache()
+              .At(0);
 #pragma omp parallel for collapse(2)
       for (int ni = 0; ni < n; ni++) {
         for (int ci = 0; ci < C; ci++) {
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
@@ -52,8 +52,10 @@ struct EmbeddingVSumFunctor {
                                   out_width, jit::SeqPoolType::kSum);
     for (size_t i = 0; i != ids_lod.size() - 1; ++i) {
       attr.index_height = ids_lod[i + 1] - ids_lod[i];
-      auto emb_seqpool = jit::Get<jit::kEmbSeqPool, jit::EmbSeqPoolTuples<T>,
-                                  platform::CPUPlace>(attr);
+      auto emb_seqpool =
+          jit::KernelFuncs<jit::kEmbSeqPool, jit::EmbSeqPoolTuples<T>,
+                           platform::CPUPlace>::Cache()
+              .At(attr);
       emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width,
                   &attr);
     }
@@ -135,8 +137,10 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
         T* d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
         const T* d_output_data = d_output->data<T>();
-        auto vbroadcast = jit::Get<jit::kVBroadcast, jit::VBroadcastTuples<T>,
-                                   platform::CPUPlace>(out_width);
+        auto vbroadcast =
+            jit::KernelFuncs<jit::kVBroadcast, jit::VBroadcastTuples<T>,
+                             platform::CPUPlace>::Cache()
+                .At(out_width);
         for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
           int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
           const T* src = d_output_data + i * out_width;
paddle/fluid/operators/fused/fusion_gru_op.cc
@@ -195,12 +195,15 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
       jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
   jit::gru_t one_step; \
-  auto ComputeH1 = \
-      jit::Get<jit::kGRUH1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
-  auto ComputeHtPart1 = \
-      jit::Get<jit::kGRUHtPart1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
-  auto ComputeHtPart2 = \
-      jit::Get<jit::kGRUHtPart2, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
+  auto ComputeH1 = jit::KernelFuncs<jit::kGRUH1, jit::GRUTuples<T>, \
+                                    platform::CPUPlace>::Cache() \
+                       .At(attr); \
+  auto ComputeHtPart1 = jit::KernelFuncs<jit::kGRUHtPart1, jit::GRUTuples<T>, \
+                                         platform::CPUPlace>::Cache() \
+                            .At(attr); \
+  auto ComputeHtPart2 = jit::KernelFuncs<jit::kGRUHtPart2, jit::GRUTuples<T>, \
+                                         platform::CPUPlace>::Cache() \
+                            .At(attr); \
   const T* x_data = x->data<T>(); \
   const T* wx_data = wx->data<T>(); \
   const T* wh_data = wh->data<T>(); \
paddle/fluid/operators/fused/fusion_lstm_op.cc
@@ -257,10 +257,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
   jit::lstm_t one_step; \
   one_step.wp = wp_data; \
   one_step.checked = checked_cell_data; \
-  auto ComputeC1H1 = \
-      jit::Get<jit::kLSTMC1H1, jit::LSTMTuples<T>, platform::CPUPlace>(attr); \
-  auto ComputeCtHt = \
-      jit::Get<jit::kLSTMCtHt, jit::LSTMTuples<T>, platform::CPUPlace>(attr)
+  auto ComputeC1H1 = jit::KernelFuncs<jit::kLSTMC1H1, jit::LSTMTuples<T>, \
+                                      platform::CPUPlace>::Cache() \
+                         .At(attr); \
+  auto ComputeCtHt = jit::KernelFuncs<jit::kLSTMCtHt, jit::LSTMTuples<T>, \
+                                      platform::CPUPlace>::Cache() \
+                         .At(attr)
 // Wh GEMM
 #define GEMM_WH_ADDON(bs, prev, out) \
paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
@@ -81,10 +81,12 @@ void FusionRepeatedFCReluOpMaker::Make() {
 template <typename T>
 static void fc_relu(const T* x, const T* w, const T* b, T* y,
                     const jit::matmul_attr_t& attr) {
-  auto matmul = jit::Get<jit::kMatMul, jit::MatMulTuples<T>,
-                         platform::CPUPlace>(attr);
-  auto addbias_relu = jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>,
-                               platform::CPUPlace>(attr.n);
+  auto matmul = jit::KernelFuncs<jit::kMatMul, jit::MatMulTuples<T>,
+                                 platform::CPUPlace>::Cache()
+                    .At(attr);
+  auto addbias_relu = jit::KernelFuncs<jit::kVAddRelu, jit::XYZNTuples<T>,
+                                       platform::CPUPlace>::Cache()
+                          .At(attr.n);
   matmul(x, w, y, &attr);
   T* dst = y;
   for (int i = 0; i < attr.m; ++i) {
paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc
@@ -97,9 +97,9 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
     } else if (pooltype == "SQRT") {
       attr.type = jit::SeqPoolType::kSqrt;
     }
-    auto seqpool =
-        jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
-            attr);
+    auto seqpool =
+        jit::KernelFuncs<jit::kSeqPool, jit::SeqPoolTuples<T>,
+                         platform::CPUPlace>::Cache().At(attr);
     size_t n = ins.size();
     size_t dst_step_size = n * w;
     for (size_t i = 0; i < n; ++i) {
paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc
@@ -93,20 +93,24 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
     attr.n = y_dims[1];
     int o_numel = attr.m * attr.n;
-    auto vsquare_x = jit::Get<jit::kVSquare, jit::XYNTuples<T>,
-                              platform::CPUPlace>(attr.m * attr.k);
-    auto vsquare_y = jit::Get<jit::kVSquare, jit::XYNTuples<T>,
-                              platform::CPUPlace>(attr.k * attr.n);
-    auto vsquare_xy = jit::Get<jit::kVSquare, jit::XYNTuples<T>,
-                               platform::CPUPlace>(o_numel);
-    auto vsub = jit::Get<jit::kVSub, jit::XYZNTuples<T>,
-                         platform::CPUPlace>(o_numel);
-    auto vscal = jit::Get<jit::kVScal, jit::AXYNTuples<T>,
-                          platform::CPUPlace>(o_numel);
-    auto matmul = jit::Get<jit::kMatMul, jit::MatMulTuples<T>,
-                           platform::CPUPlace>(attr);
+    auto vsquare_x = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>,
+                                      platform::CPUPlace>::Cache()
+                         .At(attr.m * attr.k);
+    auto vsquare_y = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>,
+                                      platform::CPUPlace>::Cache()
+                         .At(attr.k * attr.n);
+    auto vsquare_xy = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>,
+                                       platform::CPUPlace>::Cache()
+                          .At(o_numel);
+    auto vsub = jit::KernelFuncs<jit::kVSub, jit::XYZNTuples<T>,
+                                 platform::CPUPlace>::Cache()
+                    .At(o_numel);
+    auto vscal = jit::KernelFuncs<jit::kVScal, jit::AXYNTuples<T>,
+                                  platform::CPUPlace>::Cache()
+                     .At(o_numel);
+    auto matmul = jit::KernelFuncs<jit::kMatMul, jit::MatMulTuples<T>,
+                                   platform::CPUPlace>::Cache()
+                      .At(attr);
     const T* x_data = x->data<T>();
     const T* y_data = y->data<T>();
paddle/fluid/operators/jit/CMakeLists.txt
@@ -5,7 +5,7 @@ file(APPEND ${jit_file} "\#pragma once\n")
 file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n")
 file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n")
 
-set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place)
+set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place xxhash)
 
 file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
 list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc)
paddle/fluid/operators/jit/benchmark.cc
@@ -142,7 +142,7 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
     }
   }
   // Test result from Get function
-  auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr);
+  auto tgt = jit::KernelFuncs<KT, KernelTuples, PlaceType>::Cache().At(attr);
   if (!tgt) {
     LOG(FATAL) << "Target can not be empty!";
   }
paddle/fluid/operators/jit/helper.h
@@ -14,6 +14,9 @@
 
 #pragma once
 
+extern "C" {
+#include <xxhash.h>
+}
+
 #include <iostream>
 #include <string>
 #include <vector>
@@ -127,23 +130,36 @@ class KernelFuncs {
     return g_func_cache;
   }
 
-  bool Has(int key) const { return funcs_.find(key) != funcs_.end(); }
-
-  void Insert(int key, typename KernelTuples::func_type func) {
-    funcs_.emplace(key, func);
-  }
-
-  typename KernelTuples::func_type At(int key) {
+  // the exposed interface to use
+  typename KernelTuples::func_type At(
+      const typename KernelTuples::attr_type& attr) {
+    // XXH64: 13.8 GB/s
+    int64_t key = XXH64(&attr, sizeof(typename KernelTuples::attr_type), 0);
     if (Has(key)) {
      return funcs_.at(key);
     }
-    auto func = Get<KT, KernelTuples, PlaceType>(key);
+    // If do not have this attr in cache,
+    // then could run some runtime benchmark of this attr and save the best one.
+    // Here just get the offline benchmarked best one.
+    auto func = Get<KT, KernelTuples, PlaceType>(attr);
     Insert(key, func);
     return func;
   }
 
+  typename KernelTuples::func_type operator[](
+      const typename KernelTuples::attr_type& attr) {
+    return At(attr);
+  }
+
+ protected:
+  bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); }
+
+  void Insert(int64_t key, typename KernelTuples::func_type func) {
+    funcs_.emplace(key, func);
+  }
+
  private:
-  std::unordered_map<int, typename KernelTuples::func_type> funcs_;
+  std::unordered_map<int64_t, typename KernelTuples::func_type> funcs_;
   DISABLE_COPY_AND_ASSIGN(KernelFuncs);
 };
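For the cache key, the new At() hashes the attr's raw bytes with int64_t key = XXH64(&attr, sizeof(attr_type), 0); the xxhash dependency is pulled in via the extern "C" include above and the JIT_KERNEL_DEPS change in CMakeLists.txt. A tiny stand-alone illustration of that keying follows, assuming a simplified POD attr struct and using FNV-1a purely as a stand-in for XXH64 (neither of which is Paddle's code).

// Sketch of the attr-bytes -> int64_t cache key computed by the new At().
// seq_pool_attr_t here is a simplified mock, and FNV-1a stands in for XXH64.
#include <cstddef>
#include <cstdint>
#include <iostream>

struct seq_pool_attr_t {  // mock POD attr, not Paddle's definition
  int w;
  int type;
  int h;
};

int64_t HashAttrBytes(const void* p, std::size_t n) {
  std::uint64_t h = 1469598103934665603ull;  // FNV-1a offset basis
  const unsigned char* b = static_cast<const unsigned char*>(p);
  for (std::size_t i = 0; i < n; ++i) h = (h ^ b[i]) * 1099511628211ull;
  return static_cast<int64_t>(h);
}

int main() {
  seq_pool_attr_t a{128, 0, 4};
  seq_pool_attr_t b{128, 0, 4};  // same field values -> same key -> cache hit
  seq_pool_attr_t c{128, 0, 5};  // different h -> different key -> new entry
  std::cout << (HashAttrBytes(&a, sizeof(a)) == HashAttrBytes(&b, sizeof(b))) << " "
            << (HashAttrBytes(&a, sizeof(a)) != HashAttrBytes(&c, sizeof(c))) << "\n";
  // prints: 1 1
  return 0;
}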
paddle/fluid/operators/jit/test.cc
@@ -462,7 +462,7 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
   }
   // test result from Get function
   // VLOG(10) << "Test Get function ";
-  auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr);
+  auto tgt = jit::KernelFuncs<KT, KernelTuples, PlaceType>::Cache().At(attr);
   test(tgt, args...);
 }
@@ -845,7 +845,9 @@ void TestKernelNCHW16CMulNCTuples() {
   T* zjit_data = zjit.data();
   constexpr int simd_width = ZMM_FLOAT_BLOCK;
   int C = c / simd_width;
-  auto tgt = jit::Get<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0);
+  auto tgt =
+      jit::KernelFuncs<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>::Cache().At(
+          0);
   auto jitcode = jit::GetJitCode<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0);
   EXPECT_TRUE(tgt != nullptr);
@@ -970,7 +972,7 @@ void TestKernelVBroadcastTuples() {
 #define TEST_CPU_KERNEL(test_tuple, kernel_type) \
   TEST(JITKernel, kernel_type) { \
     TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
-    TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
+    TestKernel##test_tuple<jit::kernel_type, double, CPUPlace>(); \
   }
 
 TEST_CPU_KERNEL(XYZNTuples, kVMul);
@@ -1041,4 +1043,18 @@ TEST(JITKernel_key, gru) {
   EXPECT_TRUE(key2 == key3);
   EXPECT_TRUE(key3 != key4);
 }
 // TODO(TJ): add more test about key and pool
+
+TEST(JITKernel, kernel_func) {
+  auto f1 = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<float>,
+                             CPUPlace>::Cache()
+                .At(3);
+  auto f2 =
+      jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<float>, CPUPlace>::Cache()[3];
+  EXPECT_TRUE(f1 == f2);
+
+  f1 = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<float>, CPUPlace>::Cache()
+           .At(3);
+  f2 = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<float>, CPUPlace>::Cache()
+           .At(4);
+  EXPECT_TRUE(f1 != f2);
+}
paddle/fluid/operators/layer_norm_op.h
@@ -229,9 +229,9 @@ class LayerNormKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE_EQ(scale->numel(), right);
       PADDLE_ENFORCE_EQ(bias->numel(), right);
 
-      auto ker =
-          jit::Get<jit::kLayerNorm, jit::LayerNormTuples<T>,
-                   platform::CPUPlace>(right);
+      auto ker = jit::KernelFuncs<jit::kLayerNorm, jit::LayerNormTuples<T>,
+                                  platform::CPUPlace>::Cache()
+                     .At(right);
       ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
           scale->data<T>(), bias->data<T>(), static_cast<int>(left),
           static_cast<const float>(epsilon), right);
paddle/fluid/operators/math/sequence_pooling.cc
@@ -255,9 +255,9 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
       jit::seq_pool_attr_t attr(
           static_cast<int>(input.numel() / input.dims()[0]),
           jit::SeqPoolType::kSum);
-      auto seqpool =
-          jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
-              attr);
+      auto seqpool =
+          jit::KernelFuncs<jit::kSeqPool, jit::SeqPoolTuples<T>,
+                           platform::CPUPlace>::Cache().At(attr);
       for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
         attr.h = static_cast<int>(lod[i + 1] - lod[i]);
         seqpool(src, dst, &attr);
paddle/fluid/operators/optimizers/sgd_op.h
@@ -47,8 +47,9 @@ class SGDOpKernel : public framework::OpKernel<T> {
       int64_t rows_idx = 0;
       T* out_data = param_out->mutable_data<T>(ctx.GetPlace());
 
-      auto sgd =
-          jit::Get<jit::kSgd, jit::SgdTuples<T>, platform::CPUPlace>(attr);
+      auto sgd = jit::KernelFuncs<jit::kSgd, jit::SgdTuples<T>,
+                                  platform::CPUPlace>::Cache()
+                     .At(attr);
       sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
     } else if (grad_var->IsType<framework::SelectedRows>()) {
       // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
@@ -81,8 +82,9 @@ class SGDOpKernel : public framework::OpKernel<T> {
       attr.selected_rows_size = grad_rows.size();
       PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width);
 
-      auto sgd =
-          jit::Get<jit::kSgd, jit::SgdTuples<T>, platform::CPUPlace>(attr);
+      auto sgd = jit::KernelFuncs<jit::kSgd, jit::SgdTuples<T>,
+                                  platform::CPUPlace>::Cache()
+                     .At(attr);
       sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
     } else {
       PADDLE_THROW("Unsupported Variable Type of Grad");