未验证 提交 4267a81a 编写于 作者: Y Yibing Liu 提交者: GitHub

Correct the lod level of compiled time in lod_reset (#16790)

test=develop
上级 1b750494
......@@ -45,12 +45,16 @@ class InferVarTypeContext {
virtual bool HasInput(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Inputs().count(name) > 0;
auto& inputs = op_->Inputs();
auto input = inputs.find(name);
return input != inputs.end() && !input->second.empty();
}
virtual bool HasOutput(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Outputs().count(name) > 0;
auto& outputs = op_->Outputs();
auto output = outputs.find(name);
return output != outputs.end() && !output->second.empty();
}
virtual const std::vector<std::string>& Input(const std::string& name) const {
......
......@@ -30,10 +30,10 @@ class LoDResetOp : public framework::OperatorWithKernel {
if (!ctx->HasInput("Y")) {
auto level0 = ctx->Attrs().Get<std::vector<int>>("target_lod");
PADDLE_ENFORCE_GT(level0.size(), 1,
PADDLE_ENFORCE_GT(level0.size(), 0,
"If Input(Y) not provided, the target lod should be "
"specified by attribute `target_lod`.");
} else {
} else if (ctx->IsRuntime()) {
ctx->ShareLoD("Y", "Out");
}
......@@ -48,6 +48,23 @@ class LoDResetOp : public framework::OperatorWithKernel {
}
};
class LoDResetOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_var_name = ctx->Input("X").front();
auto out_var_name = ctx->Output("Out").front();
if (ctx->HasInput("Y")) {
auto y_var_name = ctx->Input("Y").front();
auto y_lod_level = std::max(ctx->GetLoDLevel(y_var_name), 1);
ctx->SetLoDLevel(out_var_name, y_lod_level);
} else {
ctx->SetLoDLevel(out_var_name, 1);
}
ctx->SetDataType(out_var_name, ctx->GetDataType(x_var_name));
ctx->SetType(out_var_name, paddle::framework::proto::VarType::LOD_TENSOR);
}
};
class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -177,9 +194,10 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(LoDResetGradNoNeedBufferVarInference,
namespace ops = paddle::operators;
REGISTER_OPERATOR(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker,
ops::LoDResetGradDescMaker);
ops::LoDResetGradDescMaker, ops::LoDResetOpVarTypeInference);
REGISTER_OPERATOR(lod_reset_grad, ops::LoDResetGradOp,
ops::LoDResetGradNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(
lod_reset, ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
ops::LoDResetKernel<paddle::platform::CPUPlace, double>,
......
......@@ -63,7 +63,7 @@ class LoDResetKernel : public framework::OpKernel<T> {
"Target LoD should be a vector end with the "
"first dimension of Input(X).");
for (size_t i = 0; i < level0.size() - 1; ++i) {
PADDLE_ENFORCE(level0[i + 1] > level0[i],
PADDLE_ENFORCE(level0[i + 1] >= level0[i],
"Target LoD should be an ascending vector.");
}
......
......@@ -1759,10 +1759,20 @@ class TestBook(LayerTest):
def test_lod_reset(self):
# TODO(minqiyang): dygraph do not support lod now
with self.static_graph():
# case 1
x = layers.data(name='x', shape=[10], dtype='float32')
y = layers.data(
name='y', shape=[10, 20], dtype='float32', lod_level=2)
return (layers.lod_reset(x=x, y=y))
z = layers.lod_reset(x=x, y=y)
self.assertTrue(z.lod_level == 2)
# case 2
lod_tensor_in = layers.data(name='lod_in', shape=[1], dtype='int64')
z = layers.lod_reset(x=x, y=lod_tensor_in)
self.assertTrue(z.lod_level == 1)
# case 3
z = layers.lod_reset(x=x, target_lod=[1, 2, 3])
self.assertTrue(z.lod_level == 1)
return z
def test_affine_grid(self):
with self.static_graph():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册