提交 d8b69735 编写于 作者: T tangwei12

update height_sections to int64_t

上级 318ba991
...@@ -22,9 +22,9 @@ class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -22,9 +22,9 @@ class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override { void Make() override {
AddInput("X", "The input SelectedRows."); AddInput("X", "The input SelectedRows.");
AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable(); AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable();
AddAttr<std::vector<int>>("height_sections", AddAttr<std::vector<int64_t>>("height_sections",
"Height for each output SelectedRows.") "Height for each output SelectedRows.")
.SetDefault(std::vector<int>({})); .SetDefault(std::vector<int64_t>({}));
AddComment(R"DOC( AddComment(R"DOC(
Split a SelectedRows with a specified rows section. Split a SelectedRows with a specified rows section.
......
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static int FindOutIdx(int row, const std::vector<int>& abs_sections) { static int FindOutIdx(int row, const std::vector<int64_t>& abs_sections) {
for (size_t i = 1; i < abs_sections.size(); ++i) { for (size_t i = 1; i < abs_sections.size(); ++i) {
if (row < abs_sections[i]) { if (row < abs_sections[i]) {
return i - 1; return i - 1;
...@@ -30,9 +30,9 @@ static int FindOutIdx(int row, const std::vector<int>& abs_sections) { ...@@ -30,9 +30,9 @@ static int FindOutIdx(int row, const std::vector<int>& abs_sections) {
return abs_sections.size() - 1; return abs_sections.size() - 1;
} }
static std::vector<int> ToAbsoluteSection( static std::vector<int64_t> ToAbsoluteSection(
const std::vector<int>& height_sections) { const std::vector<int64_t>& height_sections) {
std::vector<int> abs_sections; std::vector<int64_t> abs_sections;
abs_sections.resize(height_sections.size()); abs_sections.resize(height_sections.size());
abs_sections[0] = 0; abs_sections[0] = 0;
for (size_t i = 1; i < height_sections.size(); ++i) { for (size_t i = 1; i < height_sections.size(); ++i) {
...@@ -47,7 +47,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> { ...@@ -47,7 +47,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::SelectedRows>("X"); auto* x = ctx.Input<framework::SelectedRows>("X");
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out"); auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
auto height_sections = ctx.Attr<std::vector<int>>("height_sections"); auto height_sections = ctx.Attr<std::vector<int64_t>>("height_sections");
auto abs_sections = ToAbsoluteSection(height_sections); auto abs_sections = ToAbsoluteSection(height_sections);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册