Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
a0c37662
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a0c37662
编写于
2月 22, 2019
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enable sgd jitkernel refer code and test
test=develop
上级
1dad36f6
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
211 addition
and
34 deletion
+211
-34
paddle/fluid/operators/jit/gen/jitcode.h
paddle/fluid/operators/jit/gen/jitcode.h
+2
-1
paddle/fluid/operators/jit/helper.cc
paddle/fluid/operators/jit/helper.cc
+1
-0
paddle/fluid/operators/jit/helper.h
paddle/fluid/operators/jit/helper.h
+8
-0
paddle/fluid/operators/jit/kernel_base.h
paddle/fluid/operators/jit/kernel_base.h
+23
-0
paddle/fluid/operators/jit/kernel_key.cc
paddle/fluid/operators/jit/kernel_key.cc
+5
-0
paddle/fluid/operators/jit/refer/CMakeLists.txt
paddle/fluid/operators/jit/refer/CMakeLists.txt
+1
-0
paddle/fluid/operators/jit/refer/refer.cc
paddle/fluid/operators/jit/refer/refer.cc
+2
-0
paddle/fluid/operators/jit/refer/refer.h
paddle/fluid/operators/jit/refer/refer.h
+32
-0
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+102
-3
paddle/fluid/operators/optimizers/sgd_op.h
paddle/fluid/operators/optimizers/sgd_op.h
+35
-30
未找到文件。
paddle/fluid/operators/jit/gen/jitcode.h
浏览文件 @
a0c37662
...
@@ -31,7 +31,8 @@ namespace gen {
...
@@ -31,7 +31,8 @@ namespace gen {
// Application Binary Interface
// Application Binary Interface
constexpr
Xbyak
::
Operand
::
Code
abi_param1
(
Xbyak
::
Operand
::
RDI
),
constexpr
Xbyak
::
Operand
::
Code
abi_param1
(
Xbyak
::
Operand
::
RDI
),
abi_param2
(
Xbyak
::
Operand
::
RSI
),
abi_param3
(
Xbyak
::
Operand
::
RDX
),
abi_param2
(
Xbyak
::
Operand
::
RSI
),
abi_param3
(
Xbyak
::
Operand
::
RDX
),
abi_param4
(
Xbyak
::
Operand
::
RCX
);
abi_param4
(
Xbyak
::
Operand
::
RCX
),
abi_param5
(
Xbyak
::
Operand
::
R8
),
abi_param6
(
Xbyak
::
Operand
::
R9
);
constexpr
Xbyak
::
Operand
::
Code
g_abi_regs
[]
=
{
constexpr
Xbyak
::
Operand
::
Code
g_abi_regs
[]
=
{
Xbyak
::
Operand
::
RBX
,
Xbyak
::
Operand
::
RBP
,
Xbyak
::
Operand
::
R12
,
Xbyak
::
Operand
::
RBX
,
Xbyak
::
Operand
::
RBP
,
Xbyak
::
Operand
::
R12
,
...
...
paddle/fluid/operators/jit/helper.cc
浏览文件 @
a0c37662
...
@@ -55,6 +55,7 @@ const char* to_string(KernelType kt) {
...
@@ -55,6 +55,7 @@ const char* to_string(KernelType kt) {
ONE_CASE
(
kHSum
);
ONE_CASE
(
kHSum
);
ONE_CASE
(
kSoftmax
);
ONE_CASE
(
kSoftmax
);
ONE_CASE
(
kEmbSeqPool
);
ONE_CASE
(
kEmbSeqPool
);
ONE_CASE
(
kSgd
);
default:
default:
PADDLE_THROW
(
"Not support type: %d, or forget to add it."
,
kt
);
PADDLE_THROW
(
"Not support type: %d, or forget to add it."
,
kt
);
return
"NOT JITKernel"
;
return
"NOT JITKernel"
;
...
...
paddle/fluid/operators/jit/helper.h
浏览文件 @
a0c37662
...
@@ -181,6 +181,14 @@ inline std::ostream& operator<<(std::ostream& os,
...
@@ -181,6 +181,14 @@ inline std::ostream& operator<<(std::ostream& os,
return
os
;
return
os
;
}
}
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
sgd_attr_t
&
attr
)
{
os
<<
"param_height["
<<
attr
.
param_height
<<
"],param_width["
<<
attr
.
param_width
<<
"],grad_height["
<<
attr
.
grad_height
<<
"],grad_width["
<<
attr
.
grad_width
<<
"],selected_rows_size["
<<
attr
.
selected_rows_size
<<
"]"
;
return
os
;
}
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
matmul_attr_t
&
attr
)
{
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
matmul_attr_t
&
attr
)
{
os
<<
"M["
<<
attr
.
m
<<
"],N["
<<
attr
.
n
<<
"],K["
<<
attr
.
k
<<
"]"
;
os
<<
"M["
<<
attr
.
m
<<
"],N["
<<
attr
.
n
<<
"],K["
<<
attr
.
k
<<
"]"
;
return
os
;
return
os
;
...
...
paddle/fluid/operators/jit/kernel_base.h
浏览文件 @
a0c37662
...
@@ -46,6 +46,7 @@ typedef enum {
...
@@ -46,6 +46,7 @@ typedef enum {
kVMul
,
kVMul
,
kVRelu
,
kVRelu
,
kVScal
,
kVScal
,
kSgd
,
kVSigmoid
,
kVSigmoid
,
kVSquare
,
kVSquare
,
kVSub
,
kVSub
,
...
@@ -173,6 +174,28 @@ struct EmbSeqPoolTuples {
...
@@ -173,6 +174,28 @@ struct EmbSeqPoolTuples {
const
emb_seq_pool_attr_t
*
);
const
emb_seq_pool_attr_t
*
);
};
};
typedef
struct
sgd_attr_s
{
int64_t
param_height
,
param_width
;
int64_t
grad_height
,
grad_width
;
int64_t
selected_rows_size
;
sgd_attr_s
()
=
default
;
explicit
sgd_attr_s
(
int64_t
param_h
,
int64_t
param_w
,
int64_t
grad_h
,
int64_t
grad_w
,
int64_t
selected_rows_sz
)
:
param_height
(
param_h
),
param_width
(
param_w
),
grad_height
(
grad_h
),
grad_width
(
grad_w
),
selected_rows_size
(
selected_rows_sz
)
{}
}
sgd_attr_t
;
template
<
typename
T
>
struct
SgdTuples
{
typedef
T
data_type
;
typedef
sgd_attr_t
attr_type
;
typedef
void
(
*
func_type
)(
const
T
*
,
const
T
*
,
const
T
*
,
const
int64_t
*
,
T
*
,
const
sgd_attr_t
*
);
};
typedef
struct
matmul_attr_s
{
typedef
struct
matmul_attr_s
{
int
m
,
n
,
k
;
int
m
,
n
,
k
;
void
*
packed_weight
{
nullptr
};
void
*
packed_weight
{
nullptr
};
...
...
paddle/fluid/operators/jit/kernel_key.cc
浏览文件 @
a0c37662
...
@@ -61,6 +61,11 @@ size_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
...
@@ -61,6 +61,11 @@ size_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
return
attr
.
table_width
;
return
attr
.
table_width
;
}
}
template
<
>
size_t
JitCodeKey
<
sgd_attr_t
>
(
const
sgd_attr_t
&
attr
)
{
return
attr
.
grad_width
;
}
}
// namespace jit
}
// namespace jit
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/jit/refer/CMakeLists.txt
浏览文件 @
a0c37662
...
@@ -33,3 +33,4 @@ USE_JITKERNEL_REFER(kHSum)
...
@@ -33,3 +33,4 @@ USE_JITKERNEL_REFER(kHSum)
USE_JITKERNEL_REFER
(
kHMax
)
USE_JITKERNEL_REFER
(
kHMax
)
USE_JITKERNEL_REFER
(
kSoftmax
)
USE_JITKERNEL_REFER
(
kSoftmax
)
USE_JITKERNEL_REFER
(
kEmbSeqPool
)
USE_JITKERNEL_REFER
(
kEmbSeqPool
)
USE_JITKERNEL_REFER
(
kSgd
)
paddle/fluid/operators/jit/refer/refer.cc
浏览文件 @
a0c37662
...
@@ -59,4 +59,6 @@ REGISTER_REFER_KERNEL(kSoftmax, Softmax);
...
@@ -59,4 +59,6 @@ REGISTER_REFER_KERNEL(kSoftmax, Softmax);
REGISTER_REFER_KERNEL
(
kEmbSeqPool
,
EmbSeqPool
);
REGISTER_REFER_KERNEL
(
kEmbSeqPool
,
EmbSeqPool
);
REGISTER_REFER_KERNEL
(
kSgd
,
Sgd
);
#undef REGISTER_REFER_KERNEL
#undef REGISTER_REFER_KERNEL
paddle/fluid/operators/jit/refer/refer.h
浏览文件 @
a0c37662
...
@@ -446,6 +446,36 @@ void EmbSeqPool(const T* table, const int64_t* idx, T* out,
...
@@ -446,6 +446,36 @@ void EmbSeqPool(const T* table, const int64_t* idx, T* out,
}
}
}
}
// SGD algorithm:
// lr is pointor of learning rate scalar
// param is an input matrix with (param_h, param_w)
// grad is an input matrix with (grad_h, grad_w), here grad_w == param_w
// selected_rows is a vectot<int64_t> with size selected_rows_size( <= grad_h )
// out is an output matrix with (param_h, param_w)
//
// support both regular and sparse grad
// regular SGD: out[:] = param[:] - lr[0] * grad[:];
// sparse SGD: out[rows[i]][:] = param[rows[i]][:] - lr[0] * grad[i][:]
//
// Note: when use sparse SGD, and if out != param,
// the out rows which are not selected have not beed changed, which maybe empty
template
<
typename
T
>
void
Sgd
(
const
T
*
lr
,
const
T
*
param
,
const
T
*
grad
,
const
int64_t
*
rows
,
T
*
out
,
const
sgd_attr_t
*
attr
)
{
PADDLE_ENFORCE_EQ
(
attr
->
param_width
,
attr
->
grad_width
);
PADDLE_ENFORCE_LE
(
attr
->
selected_rows_size
,
attr
->
grad_height
);
for
(
int64_t
i
=
0
;
i
<
attr
->
selected_rows_size
;
++
i
)
{
auto
h_idx
=
rows
[
i
];
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
);
PADDLE_ENFORCE_GE
(
h_idx
,
0
);
for
(
int64_t
j
=
0
;
j
<
attr
->
grad_width
;
++
j
)
{
out
[
h_idx
*
attr
->
grad_width
+
j
]
=
param
[
h_idx
*
attr
->
grad_width
+
j
]
-
lr
[
0
]
*
grad
[
i
*
attr
->
grad_width
+
j
];
}
}
}
#define DECLARE_REFER_KERNEL(name, tuples) \
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
class name##Kernel : public ReferKernel<tuples<T>> { \
...
@@ -496,6 +526,8 @@ DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples);
...
@@ -496,6 +526,8 @@ DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples);
DECLARE_REFER_KERNEL
(
EmbSeqPool
,
EmbSeqPoolTuples
);
DECLARE_REFER_KERNEL
(
EmbSeqPool
,
EmbSeqPoolTuples
);
DECLARE_REFER_KERNEL
(
Sgd
,
SgdTuples
);
#undef DECLARE_REFER_KERNEL
#undef DECLARE_REFER_KERNEL
}
// namespace refer
}
// namespace refer
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
a0c37662
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <algorithm>
#include <random>
#include <random>
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -36,13 +37,13 @@ void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
...
@@ -36,13 +37,13 @@ void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
}
}
template
<
typename
T
>
template
<
typename
T
>
void
ExpectEQ
(
const
T
*
target
,
const
T
*
refer
,
in
t
n
)
{
void
ExpectEQ
(
const
T
*
target
,
const
T
*
refer
,
size_
t
n
)
{
if
(
std
::
is_floating_point
<
T
>::
value
)
{
if
(
std
::
is_floating_point
<
T
>::
value
)
{
for
(
in
t
i
=
0
;
i
<
n
;
++
i
)
{
for
(
size_
t
i
=
0
;
i
<
n
;
++
i
)
{
EXPECT_NEAR
(
target
[
i
],
refer
[
i
],
FLAGS_acc
);
EXPECT_NEAR
(
target
[
i
],
refer
[
i
],
FLAGS_acc
);
}
}
}
else
{
}
else
{
for
(
in
t
i
=
0
;
i
<
n
;
++
i
)
{
for
(
size_
t
i
=
0
;
i
<
n
;
++
i
)
{
EXPECT_EQ
(
target
[
i
],
refer
[
i
]);
EXPECT_EQ
(
target
[
i
],
refer
[
i
]);
}
}
}
}
...
@@ -296,6 +297,45 @@ struct TestFuncWithRefer<jit::EmbSeqPoolTuples<T>, std::vector<T>,
...
@@ -296,6 +297,45 @@ struct TestFuncWithRefer<jit::EmbSeqPoolTuples<T>, std::vector<T>,
}
}
};
};
template
<
typename
T
>
struct
TestFuncWithRefer
<
jit
::
SgdTuples
<
T
>
,
T
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
int64_t
>
,
std
::
vector
<
T
>
,
typename
jit
::
SgdTuples
<
T
>::
attr_type
>
{
void
operator
()(
const
typename
jit
::
SgdTuples
<
T
>::
func_type
tgt
,
const
T
lr
,
const
std
::
vector
<
T
>&
param
,
const
std
::
vector
<
T
>&
grad
,
const
std
::
vector
<
int64_t
>&
rows
,
const
std
::
vector
<
T
>&
oref
,
const
typename
jit
::
SgdTuples
<
T
>::
attr_type
&
attr
)
{
EXPECT_TRUE
(
tgt
!=
nullptr
);
EXPECT_EQ
(
param
.
size
(),
static_cast
<
size_t
>
(
attr
.
param_height
*
attr
.
param_width
));
EXPECT_EQ
(
grad
.
size
(),
static_cast
<
size_t
>
(
attr
.
grad_height
*
attr
.
grad_width
));
EXPECT_EQ
(
rows
.
size
(),
static_cast
<
size_t
>
(
attr
.
selected_rows_size
));
EXPECT_EQ
(
param
.
size
(),
oref
.
size
());
const
T
*
param_data
=
param
.
data
();
const
T
*
grad_data
=
grad
.
data
();
const
int64_t
*
rows_data
=
rows
.
data
();
const
T
*
oref_data
=
oref
.
data
();
std
::
vector
<
T
>
out
(
oref
.
size
());
T
*
o_data
=
out
.
data
();
tgt
(
&
lr
,
param_data
,
grad_data
,
rows_data
,
o_data
,
&
attr
);
// only the selected rows should be equal
for
(
size_t
i
=
0
;
i
<
rows
.
size
();
++
i
)
{
ExpectEQ
<
T
>
(
o_data
+
rows
[
i
]
*
attr
.
grad_width
,
oref_data
+
rows
[
i
]
*
attr
.
grad_width
,
attr
.
grad_width
);
}
// inplace
std
::
copy
(
param
.
begin
(),
param
.
end
(),
out
.
begin
());
tgt
(
&
lr
,
o_data
,
grad_data
,
rows_data
,
o_data
,
&
attr
);
for
(
size_t
i
=
0
;
i
<
rows
.
size
();
++
i
)
{
ExpectEQ
<
T
>
(
o_data
+
rows
[
i
]
*
attr
.
grad_width
,
oref_data
+
rows
[
i
]
*
attr
.
grad_width
,
attr
.
grad_width
);
}
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
TestFuncWithRefer
<
jit
::
MatMulTuples
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
struct
TestFuncWithRefer
<
jit
::
MatMulTuples
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
...
@@ -704,6 +744,60 @@ void TestEmbSeqPoolKernel() {
...
@@ -704,6 +744,60 @@ void TestEmbSeqPoolKernel() {
}
}
}
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
TestSgdKernel
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
const
T
lr
=
0.1
;
auto
UnDuplicatedRandomVec
=
[](
int
n
,
const
int64_t
lower
,
const
int64_t
upper
)
->
std
::
vector
<
int64_t
>
{
PADDLE_ENFORCE_LE
(
static_cast
<
size_t
>
(
upper
-
lower
),
n
-
1
);
PADDLE_ENFORCE_GT
(
n
,
0
);
std
::
vector
<
int64_t
>
all
,
out
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
all
.
push_back
(
i
);
}
std
::
random_shuffle
(
all
.
begin
(),
all
.
end
());
out
.
insert
(
out
.
begin
(),
all
.
begin
(),
all
.
begin
()
+
n
);
return
out
;
};
for
(
int
param_h
:
{
1
,
10
})
{
for
(
int
grad_w
:
TestSizes
())
{
std
::
vector
<
T
>
param
(
param_h
*
grad_w
);
std
::
vector
<
T
>
param_out
(
param_h
*
grad_w
);
RandomVec
<
T
>
(
param_h
*
grad_w
,
param
.
data
(),
-
2.
f
,
2.
f
);
const
T
*
param_data
=
param
.
data
();
T
*
out_data
=
param_out
.
data
();
for
(
int
rows_size
=
1
;
rows_size
<=
param_h
;
++
rows_size
)
{
std
::
vector
<
T
>
grad
(
rows_size
*
grad_w
);
std
::
vector
<
int64_t
>
rows
=
UnDuplicatedRandomVec
(
rows_size
,
0
,
rows_size
-
1
);
RandomVec
<
T
>
(
rows_size
*
grad_w
,
grad
.
data
(),
-
2.
f
,
2.
f
);
const
int64_t
*
rows_data
=
rows
.
data
();
const
T
*
grad_data
=
grad
.
data
();
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
SgdTuples
<
T
>>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
jit
::
sgd_attr_t
attr
(
param_h
,
grad_w
,
rows_size
,
grad_w
,
rows_size
);
ref
(
&
lr
,
param_data
,
grad_data
,
rows_data
,
out_data
,
&
attr
);
// inplace test
std
::
vector
<
T
>
inp
(
param
.
size
());
std
::
copy
(
param
.
begin
(),
param
.
end
(),
inp
.
begin
());
T
*
inp_data
=
inp
.
data
();
ref
(
&
lr
,
inp_data
,
grad_data
,
rows_data
,
inp_data
,
&
attr
);
// only the selected rows should be equal
for
(
int
i
=
0
;
i
<
rows_size
;
++
i
)
{
ExpectEQ
<
T
>
(
inp_data
+
rows
[
i
]
*
grad_w
,
out_data
+
rows
[
i
]
*
grad_w
,
grad_w
);
}
TestAllImpls
<
KT
,
jit
::
SgdTuples
<
T
>
,
PlaceType
,
T
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
int64_t
>
,
std
::
vector
<
T
>>
(
attr
,
lr
,
param
,
grad
,
rows
,
param_out
,
attr
);
}
}
}
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
TestNCHW16CMulNCKernel
()
{
void
TestNCHW16CMulNCKernel
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
...
@@ -943,6 +1037,11 @@ TEST(JITKernel, kEmbSeqPool) {
...
@@ -943,6 +1037,11 @@ TEST(JITKernel, kEmbSeqPool) {
TestEmbSeqPoolKernel
<
jit
::
kEmbSeqPool
,
double
,
CPUPlace
>
();
TestEmbSeqPoolKernel
<
jit
::
kEmbSeqPool
,
double
,
CPUPlace
>
();
}
}
TEST
(
JITKernel
,
kSgd
)
{
TestSgdKernel
<
jit
::
kSgd
,
float
,
CPUPlace
>
();
TestSgdKernel
<
jit
::
kSgd
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kNCHW16CMulNC
)
{
TEST
(
JITKernel
,
kNCHW16CMulNC
)
{
TestNCHW16CMulNCKernel
<
jit
::
kNCHW16CMulNC
,
float
,
CPUPlace
>
();
TestNCHW16CMulNCKernel
<
jit
::
kNCHW16CMulNC
,
float
,
CPUPlace
>
();
TestNCHW16CMulNCKernel
<
jit
::
kNCHW16CMulNC
,
double
,
CPUPlace
>
();
TestNCHW16CMulNCKernel
<
jit
::
kNCHW16CMulNC
,
double
,
CPUPlace
>
();
...
...
paddle/fluid/operators/optimizers/sgd_op.h
浏览文件 @
a0c37662
...
@@ -16,6 +16,7 @@ limitations under the License. */
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/jit/kernels.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -32,53 +33,57 @@ class SGDOpKernel : public framework::OpKernel<T> {
...
@@ -32,53 +33,57 @@ class SGDOpKernel : public framework::OpKernel<T> {
if
(
param_var
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
param_var
->
IsType
<
framework
::
LoDTensor
>
())
{
const
auto
*
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
const
auto
*
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
*
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
*
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
// Actually, all tensors are LoDTensor except SelectedRows.
// Actually, all tensors are LoDTensor except SelectedRows.
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
auto
*
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
const
auto
*
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
auto
sz
=
param_out
->
numel
();
auto
p
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param
);
PADDLE_ENFORCE_EQ
(
param
->
numel
(),
sz
);
auto
g
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
grad
);
PADDLE_ENFORCE_EQ
(
grad
->
numel
(),
sz
);
auto
o
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out
);
auto
*
lr
=
learning_rate
->
data
<
T
>
();
jit
::
sgd_attr_t
attr
(
1
,
sz
,
1
,
sz
,
1
);
const
T
*
lr
=
learning_rate
->
data
<
T
>
();
o
=
p
-
lr
[
0
]
*
g
;
const
T
*
param_data
=
param
->
data
<
T
>
();
const
T
*
grad_data
=
grad
->
data
<
T
>
();
int64_t
rows_idx
=
0
;
T
*
out_data
=
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
sgd
=
jit
::
Get
<
jit
::
kSgd
,
jit
::
SgdTuples
<
T
>
,
platform
::
CPUPlace
>
(
attr
);
sgd
(
lr
,
param_data
,
grad_data
,
&
rows_idx
,
out_data
,
&
attr
);
}
else
if
(
grad_var
->
IsType
<
framework
::
SelectedRows
>
())
{
}
else
if
(
grad_var
->
IsType
<
framework
::
SelectedRows
>
())
{
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
// This manual optimization brings difficulty to track data dependency.
// This manual optimization brings difficulty to track data dependency.
// It's better to find a more elegant solution.
// It's better to find a more elegant solution.
PADDLE_ENFORCE_EQ
(
param
,
param_out
);
PADDLE_ENFORCE_EQ
(
param
,
param_out
);
const
auto
*
grad
=
ctx
.
Input
<
framework
::
SelectedRows
>
(
"Grad"
);
const
auto
*
grad
=
ctx
.
Input
<
framework
::
SelectedRows
>
(
"Grad"
);
auto
&
grad_rows
=
grad
->
rows
();
// for distributed training, a sparse var may be empty,
// for distributed training, a sparse var may be empty,
// just skip updating.
// just skip updating.
if
(
grad
->
rows
()
.
size
()
==
0
)
{
if
(
grad
_rows
.
size
()
==
0
)
{
return
;
return
;
}
}
auto
grad_height
=
grad
->
height
();
auto
out_dims
=
param_out
->
dims
();
auto
out_dims
=
param_out
->
dims
();
PADDLE_ENFORCE_EQ
(
grad_height
,
out_dims
[
0
]);
PADDLE_ENFORCE_EQ
(
grad
->
height
(),
out_dims
[
0
]);
auto
&
grad_value
=
grad
->
value
();
auto
&
grad_value
=
grad
->
value
();
auto
&
grad_rows
=
grad
->
rows
();
const
T
*
param_data
=
param
->
data
<
T
>
();
const
T
*
grad_data
=
grad_value
.
data
<
T
>
();
size_t
grad_row_numel
=
grad_value
.
numel
()
/
grad_rows
.
size
();
const
T
*
lr
=
learning_rate
->
data
<
T
>
();
PADDLE_ENFORCE_EQ
(
static_cast
<
int64_t
>
(
grad_row_numel
),
const
int64_t
*
rows_data
=
grad_rows
.
data
();
param_out
->
numel
()
/
grad_height
);
T
*
out_data
=
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()
);
auto
*
grad_data
=
grad_value
.
data
<
T
>
()
;
jit
::
sgd_attr_t
attr
;
a
uto
*
out_data
=
param_out
->
data
<
T
>
()
;
a
ttr
.
param_height
=
out_dims
[
0
]
;
a
uto
*
lr
=
learning_rate
->
data
<
T
>
()
;
a
ttr
.
param_width
=
param_out
->
numel
()
/
attr
.
param_height
;
for
(
size_t
i
=
0
;
i
<
grad_rows
.
size
();
i
++
)
{
attr
.
grad_height
=
grad_rows
.
size
();
// note: it is not grad->height()
PADDLE_ENFORCE
(
grad_rows
[
i
]
<
grad_height
,
attr
.
grad_width
=
grad_value
.
numel
()
/
attr
.
grad_height
;
"Input rows index should less than height"
);
attr
.
selected_rows_size
=
grad_rows
.
size
(
);
for
(
size_t
j
=
0
;
j
<
grad_row_numel
;
j
++
)
{
PADDLE_ENFORCE_EQ
(
attr
.
grad_width
,
attr
.
param_width
);
out_data
[
grad_rows
[
i
]
*
grad_row_numel
+
j
]
-=
lr
[
0
]
*
grad_data
[
i
*
grad_row_numel
+
j
];
auto
sgd
=
}
jit
::
Get
<
jit
::
kSgd
,
jit
::
SgdTuples
<
T
>
,
platform
::
CPUPlace
>
(
attr
);
}
sgd
(
lr
,
param_data
,
grad_data
,
rows_data
,
out_data
,
&
attr
);
}
else
{
}
else
{
PADDLE_THROW
(
"Unsupported Variable Type of Grad"
);
PADDLE_THROW
(
"Unsupported Variable Type of Grad"
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录