Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
c69d2bbe
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c69d2bbe
编写于
10月 26, 2018
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add base impl
上级
c26f2b21
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
365 addition
and
0 deletion
+365
-0
paddle/fluid/operators/fused_embedding_seq_pool_op.cc
paddle/fluid/operators/fused_embedding_seq_pool_op.cc
+158
-0
paddle/fluid/operators/fused_embedding_seq_pool_op.h
paddle/fluid/operators/fused_embedding_seq_pool_op.h
+207
-0
未找到文件。
paddle/fluid/operators/fused_embedding_seq_pool_op.cc
0 → 100644
浏览文件 @
c69d2bbe
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused_embedding_seq_pool_op.h"
#include "paddle/fluid/framework/var_type_inference.h"
namespace
paddle
{
namespace
operators
{
class
LookupTableOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"W"
),
"Input(W) of LookupTableOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
"Input(Ids) of LookupTableOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of LookupTableOp should not be null."
);
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
auto
ids_dims
=
ctx
->
GetInputDim
(
"Ids"
);
int
ids_rank
=
ids_dims
.
size
();
PADDLE_ENFORCE_EQ
(
table_dims
.
size
(),
2
);
PADDLE_ENFORCE_EQ
(
ids_dims
[
ids_rank
-
1
],
1
,
"The last dimension of the 'Ids' tensor must be 1."
);
auto
output_dims
=
framework
::
vectorize
(
framework
::
slice_ddim
(
ids_dims
,
0
,
ids_rank
-
1
));
output_dims
.
push_back
(
table_dims
[
1
]);
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
output_dims
));
if
(
ctx
->
GetOutputsVarType
(
"Out"
)[
0
]
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
ctx
->
ShareLoD
(
"Ids"
,
/*->*/
"Out"
);
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_type
=
framework
::
GetDataTypeOfVar
(
ctx
.
InputVar
(
"W"
));
return
framework
::
OpKernelType
(
data_type
,
ctx
.
device_context
());
}
};
class
LookupTableOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"W"
,
"(Tensor) The input represents embedding tensors, "
"which is a learnable parameter."
);
AddInput
(
"Ids"
,
"An input with type int32 or int64 "
"contains the ids to be looked up in W. "
"The last dimension size must be 1."
);
AddOutput
(
"Out"
,
"The lookup results, which have the same type as W."
);
AddAttr
<
bool
>
(
"is_sparse"
,
"(boolean, default false) "
"Sparse update."
)
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"is_distributed"
,
"(boolean, default false) distributed lookup table."
)
.
SetDefault
(
false
);
AddAttr
<
int64_t
>
(
"padding_idx"
,
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids."
)
.
SetDefault
(
kNoPadding
);
AddComment
(
R"DOC(
Lookup Table Operator.
This operator is used to perform lookups on the parameter W,
then concatenated into a dense tensor.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC"
);
}
};
class
LookupTableOpGradDescMaker
:
public
framework
::
DefaultGradOpDescMaker
<
true
>
{
using
::
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>::
DefaultGradOpDescMaker
;
protected:
virtual
std
::
string
GradOpType
()
const
{
return
"lookup_table_grad"
;
}
};
class
LookupTableOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"W"
),
table_dims
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_type
=
framework
::
GetDataTypeOfVar
(
ctx
.
InputVar
(
"W"
));
return
framework
::
OpKernelType
(
data_type
,
ctx
.
device_context
());
}
};
class
LookupTableOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
framework
::
GradVarName
(
"W"
)).
front
();
auto
attr
=
op_desc
.
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
if
(
is_sparse
)
{
VLOG
(
3
)
<<
"lookup_table_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to SelectedRows"
;
block
->
Var
(
out_var_name
)
->
SetType
(
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
VLOG
(
3
)
<<
"lookup_table_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to LoDTensor"
;
block
->
Var
(
out_var_name
)
->
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
block
->
Var
(
out_var_name
)
->
SetDataType
(
block
->
Var
(
"W"
)
->
GetDataType
());
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
lookup_table
,
ops
::
LookupTableOp
,
ops
::
LookupTableOpGradDescMaker
,
ops
::
LookupTableOpMaker
);
REGISTER_OPERATOR
(
lookup_table_grad
,
ops
::
LookupTableOpGrad
,
ops
::
LookupTableOpGradVarTypeInference
);
// REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
// ops::LookupTableKernel<double>);
// REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>,
// ops::LookupTableGradKernel<double>);
REGISTER_OP_CPU_KERNEL
(
lookup_table
,
ops
::
LookupTableKernel
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
lookup_table_grad
,
ops
::
LookupTableGradKernel
<
float
>
);
paddle/fluid/operators/fused_embedding_seq_pool_op.h
0 → 100644
浏览文件 @
c69d2bbe
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/blas.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
using
SelectedRows
=
framework
::
SelectedRows
;
using
DDim
=
framework
::
DDim
;
constexpr
int64_t
kNoPadding
=
-
1
;
template
<
typename
T
>
class
LookupTableKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
// int tensor
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
// float tensor
auto
*
table_var
=
context
.
InputVar
(
"W"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
int64_t
*
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
data
<
int64_t
>
());
int64_t
ids_numel
=
ids_t
->
numel
();
if
(
table_var
->
IsType
<
LoDTensor
>
())
{
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
int64_t
row_number
=
table_t
->
dims
()[
0
];
int64_t
row_width
=
table_t
->
dims
()[
1
];
auto
*
table
=
table_t
->
data
<
T
>
();
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int64_t
i
=
0
;
i
<
ids_numel
;
++
i
)
{
if
(
padding_idx
!=
kNoPadding
&&
ids
[
i
]
==
padding_idx
)
{
memset
(
output
+
i
*
row_width
,
0
,
row_width
*
sizeof
(
T
));
}
else
{
PADDLE_ENFORCE_LT
(
ids
[
i
],
row_number
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
,
"ids %d"
,
i
);
memcpy
(
output
+
i
*
row_width
,
table
+
ids
[
i
]
*
row_width
,
row_width
*
sizeof
(
T
));
}
}
}
else
if
(
table_var
->
IsType
<
SelectedRows
>
())
{
const
auto
&
table_t
=
table_var
->
Get
<
SelectedRows
>
();
int64_t
row_width
=
table_t
.
value
().
dims
()[
1
];
const
auto
*
table
=
table_t
.
value
().
data
<
T
>
();
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
context
);
for
(
int64_t
i
=
0
;
i
<
ids_numel
;
++
i
)
{
if
(
padding_idx
!=
kNoPadding
&&
ids
[
i
]
==
padding_idx
)
{
memset
(
output
+
i
*
row_width
,
0
,
row_width
*
sizeof
(
T
));
}
else
{
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
auto
id_index
=
table_t
.
Index
(
ids
[
i
]);
PADDLE_ENFORCE_GE
(
id_index
,
0
,
"the input key should be exists."
);
// memcpy(output + i * row_width, table + id_index * row_width,
// row_width * sizeof(T));
blas
.
VCOPY
(
row_width
,
table
+
id_index
*
row_width
,
output
+
i
*
row_width
);
}
}
}
}
};
template
<
typename
T
>
class
LookupTableGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
table_var
=
context
.
InputVar
(
"W"
);
DDim
table_dim
;
if
(
table_var
->
IsType
<
LoDTensor
>
())
{
table_dim
=
context
.
Input
<
LoDTensor
>
(
"W"
)
->
dims
();
}
else
if
(
table_var
->
IsType
<
SelectedRows
>
())
{
auto
*
table_t
=
context
.
Input
<
SelectedRows
>
(
"W"
);
table_dim
=
table_t
->
value
().
dims
();
}
else
{
PADDLE_THROW
(
"The parameter W of a LookupTable "
"must be either LoDTensor or SelectedRows"
);
}
bool
is_sparse
=
context
.
Attr
<
bool
>
(
"is_sparse"
);
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if
(
is_sparse
)
{
// auto start = std::chrono::system_clock::now();
auto
*
ids
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
d_output
=
context
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_table
=
context
.
Output
<
SelectedRows
>
(
framework
::
GradVarName
(
"W"
));
auto
*
ids_data
=
ids
->
data
<
int64_t
>
();
int64_t
ids_num
=
ids
->
numel
();
// auto end = std::chrono::system_clock::now();
// std::chrono::duration<double> diff = end - start;
// auto copy_start = std::chrono::system_clock::now();
std
::
vector
<
int64_t
>
new_rows
;
new_rows
.
resize
(
ids_num
);
std
::
memcpy
(
&
new_rows
[
0
],
ids_data
,
ids_num
*
sizeof
(
int64_t
));
// for (int64_t i = 0; i < ids_num; i++) {
// new_rows.push_back(ids_data[i]);
// }
// auto copy_end = std::chrono::system_clock::now();
// std::chrono::duration<double> copy_diff = copy_end - copy_start;
// diff += copy_diff;
// LOG(ERROR) << "run emb_grad copy end, cost: " << copy_diff.count() << "
// " << ids_num;
// copy_start = std::chrono::system_clock::now();
d_table
->
set_rows
(
new_rows
);
auto
*
d_table_value
=
d_table
->
mutable_value
();
d_table_value
->
Resize
({
ids_num
,
table_dim
[
1
]});
d_table_value
->
ShareDataWith
(
*
d_output
);
// d_table_value->mutable_data<T>(context.GetPlace());
// // copy_end = std::chrono::system_clock::now();
// // copy_diff = copy_end - copy_start;
// // diff += copy_diff;
// // LOG(ERROR) << "run emb_grad resize table end, cost: " <<
// // copy_diff.count() << " " << ids_num;
// // copy_start = std::chrono::system_clock::now();
// d_table->set_height(table_dim[0]);
// auto *d_output_data = d_output->data<T>();
// auto *d_table_data = d_table_value->data<T>();
// // copy_end = std::chrono::system_clock::now();
// // copy_diff = copy_end - copy_start;
// // diff += copy_diff;
// // LOG(ERROR) << "run emb_grad set height end, cost: " <<
// // copy_diff.count() << " " << ids_num;
// auto d_output_dims = d_output->dims();
// PADDLE_ENFORCE_EQ(
// d_table_value->dims(),
// framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
// // copy_start = std::chrono::system_clock::now();
// auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
// blas.VCOPY(d_output->numel(), d_output_data, d_table_data);
// cblas_scopy(d_output->numel(), d_output_data, 1, d_table_data, 1);
// // for (int i = 0; i != d_output->numel(), ++i) {
// // *(d_table_data++) = *(d_output_data++);
// // }
// // memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
// // copy_end = std::chrono::system_clock::now();
// // copy_diff = copy_end - copy_start;
// // diff += copy_diff;
// // LOG(ERROR) << "run emb_grad core end, cost: " << copy_diff.count()
// << "
// // " << ids_num << " " << d_output->numel();
// // LOG(ERROR) << "run emb_grad end, cost: " << diff.count();
}
else
{
auto
*
ids
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
d_output
=
context
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_table
=
context
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"W"
));
auto
*
ids_data
=
ids
->
data
<
int64_t
>
();
int
N
=
table_dim
[
0
];
int
D
=
table_dim
[
1
];
auto
*
d_output_data
=
d_output
->
data
<
T
>
();
auto
*
d_table_data
=
d_table
->
mutable_data
<
T
>
(
context
.
GetPlace
());
memset
(
d_table_data
,
0
,
d_table
->
numel
()
*
sizeof
(
T
));
for
(
int64_t
i
=
0
;
i
<
ids
->
numel
();
++
i
)
{
PADDLE_ENFORCE_LT
(
ids_data
[
i
],
N
);
PADDLE_ENFORCE_GE
(
ids_data
[
i
],
0
);
for
(
int
j
=
0
;
j
<
D
;
++
j
)
{
d_table_data
[
ids_data
[
i
]
*
D
+
j
]
+=
d_output_data
[
i
*
D
+
j
];
}
}
}
}
};
}
// namespace operators
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录