Unverified commit 07462de8 authored by Yibing Liu, committed by GitHub

Cherry-pick lod_reset fix to 1.4 (#16939)

test=release/1.4
Parent d1c5da26
@@ -45,12 +45,16 @@ class InferVarTypeContext {
   virtual bool HasInput(const std::string& name) const {
     PADDLE_ENFORCE_NOT_NULL(op_);
-    return op_->Inputs().count(name) > 0;
+    auto& inputs = op_->Inputs();
+    auto input = inputs.find(name);
+    return input != inputs.end() && !input->second.empty();
   }

   virtual bool HasOutput(const std::string& name) const {
     PADDLE_ENFORCE_NOT_NULL(op_);
-    return op_->Outputs().count(name) > 0;
+    auto& outputs = op_->Outputs();
+    auto output = outputs.find(name);
+    return output != outputs.end() && !output->second.empty();
   }

   virtual const std::vector<std::string>& Input(const std::string& name) const {
...
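For context on the HasInput/HasOutput fix above: an operator can declare an optional input slot whose entry exists in the input map but holds no variable names, so a bare count(name) check misreports the input as provided. A minimal standalone sketch of the two checks, using a plain std::map in place of the Paddle types:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // An op's input map: the optional slot "Y" is declared but not fed.
  std::map<std::string, std::vector<std::string>> inputs;
  inputs["X"] = {"x_var"};
  inputs["Y"] = {};  // key exists, but no variable is bound to it

  // Old check: key presence only -- wrongly reports "Y" as provided.
  bool old_has_y = inputs.count("Y") > 0;  // true

  // New check: key must exist AND hold at least one variable name.
  auto it = inputs.find("Y");
  bool new_has_y = it != inputs.end() && !it->second.empty();  // false

  std::cout << old_has_y << " " << new_has_y << "\n";  // prints: 1 0
  return 0;
}
```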
@@ -30,10 +30,10 @@ class LoDResetOp : public framework::OperatorWithKernel {
     if (!ctx->HasInput("Y")) {
       auto level0 = ctx->Attrs().Get<std::vector<int>>("target_lod");
-      PADDLE_ENFORCE_GT(level0.size(), 1,
+      PADDLE_ENFORCE_GT(level0.size(), 0,
                         "If Input(Y) not provided, the target lod should be "
                         "specified by attribute `target_lod`.");
-    } else {
+    } else if (ctx->IsRuntime()) {
       ctx->ShareLoD("Y", "Out");
     }
@@ -48,6 +48,23 @@ class LoDResetOp : public framework::OperatorWithKernel {
   }
 };

+class LoDResetOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto x_var_name = ctx->Input("X").front();
+    auto out_var_name = ctx->Output("Out").front();
+    if (ctx->HasInput("Y")) {
+      auto y_var_name = ctx->Input("Y").front();
+      auto y_lod_level = std::max(ctx->GetLoDLevel(y_var_name), 1);
+      ctx->SetLoDLevel(out_var_name, y_lod_level);
+    } else {
+      ctx->SetLoDLevel(out_var_name, 1);
+    }
+    ctx->SetDataType(out_var_name, ctx->GetDataType(x_var_name));
+    ctx->SetType(out_var_name, paddle::framework::proto::VarType::LOD_TENSOR);
+  }
+};
+
 class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -177,9 +194,10 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(LoDResetGradNoNeedBufferVarInference,
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker,
-                  ops::LoDResetGradDescMaker);
+                  ops::LoDResetGradDescMaker, ops::LoDResetOpVarTypeInference);
 REGISTER_OPERATOR(lod_reset_grad, ops::LoDResetGradOp,
                   ops::LoDResetGradNoNeedBufferVarInference);
 REGISTER_OP_CPU_KERNEL(
     lod_reset, ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
     ops::LoDResetKernel<paddle::platform::CPUPlace, double>,
...
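Together, the IsRuntime() gate above and the newly registered LoDResetOpVarTypeInference split the LoD handling: ShareLoD now runs only at runtime, while the compile-time lod_level of Out comes from the new class. A minimal sketch distilling that inference rule into a standalone function (hypothetical name, not part of Paddle):

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical distillation of LoDResetOpVarTypeInference: if Y is given,
// Out inherits Y's lod_level but never drops below 1 (Y may be a plain
// tensor of target offsets, with lod_level 0 at compile time); otherwise
// the target_lod attribute always yields a one-level LoD.
int InferOutLoDLevel(bool has_y, int y_lod_level) {
  return has_y ? std::max(y_lod_level, 1) : 1;
}

int main() {
  assert(InferOutLoDLevel(true, 2) == 2);   // Y carries a 2-level LoD
  assert(InferOutLoDLevel(true, 0) == 1);   // Y is a plain offset tensor
  assert(InferOutLoDLevel(false, 0) == 1);  // target_lod attribute used
  return 0;
}
```

The std::max(..., 1) clamp is what makes the second case work: even when Y has no compile-time LoD, the reset output still carries a one-level LoD, which the Python tests below check.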
@@ -63,7 +63,7 @@ class LoDResetKernel : public framework::OpKernel<T> {
                    "Target LoD should be a vector end with the "
                    "first dimension of Input(X).");
     for (size_t i = 0; i < level0.size() - 1; ++i) {
-      PADDLE_ENFORCE(level0[i + 1] > level0[i],
+      PADDLE_ENFORCE(level0[i + 1] >= level0[i],
                      "Target LoD should be an ascending vector.");
     }
...
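Relaxing the kernel check from > to >= admits target LoDs that contain empty sequences, which LoD offsets encode as repeated values. A minimal sketch of the relaxed validation (hypothetical helper, not the Paddle kernel):

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper mirroring the relaxed check: offsets must be
// non-decreasing; equal neighbors denote an empty sequence.
bool IsValidTargetLoD(const std::vector<int>& level0) {
  for (size_t i = 0; i + 1 < level0.size(); ++i) {
    if (level0[i + 1] < level0[i]) return false;
  }
  return true;
}

int main() {
  // {0, 2, 2, 5}: three sequences of lengths 2, 0, 3 -- the middle one
  // is empty. The old strict check (>) would have rejected it.
  assert(IsValidTargetLoD({0, 2, 2, 5}));
  // A decreasing offset is still rejected.
  assert(!IsValidTargetLoD({0, 3, 2}));
  return 0;
}
```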
@@ -1333,7 +1333,15 @@ class TestBook(unittest.TestCase):
             x = layers.data(name='x', shape=[10], dtype='float32')
             y = layers.data(
                 name='y', shape=[10, 20], dtype='float32', lod_level=2)
-            print(layers.lod_reset(x=x, y=y))
+            z = layers.lod_reset(x=x, y=y)
+            self.assertTrue(z.lod_level == 2)
+            # case 2
+            lod_tensor_in = layers.data(name='lod_in', shape=[1], dtype='int64')
+            z = layers.lod_reset(x=x, y=lod_tensor_in)
+            self.assertTrue(z.lod_level == 1)
+            # case 3
+            z = layers.lod_reset(x=x, target_lod=[1, 2, 3])
+            self.assertTrue(z.lod_level == 1)
         print(str(program))

     def test_label_smooth(self):
...