未验证 提交 33ff833a 编写于 作者: C Chen Weihang 提交者: GitHub

Fix run error when loading a layer that has no parameters (#27241)

上级 f1ab2882
...@@ -27,9 +27,6 @@ class RunProgramOp : public framework::OperatorWithKernel { ...@@ -27,9 +27,6 @@ class RunProgramOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true, PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
platform::errors::NotFound( platform::errors::NotFound(
"Input(X) of RunProgramOp should not be null.")); "Input(X) of RunProgramOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInputs("Params"), true,
platform::errors::NotFound(
"Input(Params) of RunProgramOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutputs("Out"), true, PADDLE_ENFORCE_EQ(ctx->HasOutputs("Out"), true,
platform::errors::NotFound( platform::errors::NotFound(
"Output(Out) of RunProgramOp should not be null.")); "Output(Out) of RunProgramOp should not be null."));
...@@ -73,7 +70,8 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -73,7 +70,8 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
"(vector<LoDTensor or SelecetedRows>)" "(vector<LoDTensor or SelecetedRows>)"
"The input parameter of RunProgram operator, also the parameters " "The input parameter of RunProgram operator, also the parameters "
"of the loaded program.") "of the loaded program.")
.AsDuplicable(); .AsDuplicable()
.AsDispensable();
AddOutput("Out", AddOutput("Out",
"(vector<LoDTensor>)" "(vector<LoDTensor>)"
"The output tensors of RunProgram operator, also the fetch " "The output tensors of RunProgram operator, also the fetch "
...@@ -121,10 +119,6 @@ class RunProgramGradOp : public framework::OperatorWithKernel { ...@@ -121,10 +119,6 @@ class RunProgramGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true, PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
platform::errors::NotFound( platform::errors::NotFound(
"Input(X) of RunProgramGradOp should not be null.")); "Input(X) of RunProgramGradOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInputs("Params"), true,
platform::errors::NotFound(
"Input(Params) of RunProgramGradOp should not be null."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx->HasInputs(framework::GradVarName("Out")), true, ctx->HasInputs(framework::GradVarName("Out")), true,
platform::errors::NotFound( platform::errors::NotFound(
......
...@@ -209,9 +209,14 @@ class RunProgramOpKernel : public framework::OpKernel<T> { ...@@ -209,9 +209,14 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
auto output_vars = ctx.MultiOutputVar("Out"); auto output_vars = ctx.MultiOutputVar("Out");
auto input_var_names = ctx.InputNames("X"); auto input_var_names = ctx.InputNames("X");
auto param_names = ctx.InputNames("Params");
auto output_var_names = ctx.OutputNames("Out"); auto output_var_names = ctx.OutputNames("Out");
// current program may not hold parameters
std::vector<std::string> param_names;
if (!param_vars.empty()) {
param_names = ctx.InputNames("Params");
}
auto *block = ctx.Attr<BlockDesc *>("global_block"); auto *block = ctx.Attr<BlockDesc *>("global_block");
auto *program = block->Program(); auto *program = block->Program();
auto start_op_index = ctx.Attr<int64_t>("start_op_index"); auto start_op_index = ctx.Attr<int64_t>("start_op_index");
......
...@@ -479,11 +479,15 @@ def _load_persistable_vars(model_path, ...@@ -479,11 +479,15 @@ def _load_persistable_vars(model_path,
var_file_path = os.path.join(model_path, params_filename) var_file_path = os.path.join(model_path, params_filename)
else: else:
var_file_path = os.path.join(model_path, VARIABLE_FILENAME) var_file_path = os.path.join(model_path, VARIABLE_FILENAME)
framework._dygraph_tracer().trace_op( if not os.path.exists(var_file_path):
type='load_combine', if len(extra_var_info) != 0:
inputs={}, raise ValueError("The model to be loaded is incomplete.")
outputs={'Out': load_var_list}, else:
attrs={'file_path': var_file_path}) framework._dygraph_tracer().trace_op(
type='load_combine',
inputs={},
outputs={'Out': load_var_list},
attrs={'file_path': var_file_path})
return load_var_dict return load_var_dict
......
...@@ -23,7 +23,7 @@ from paddle.static import InputSpec ...@@ -23,7 +23,7 @@ from paddle.static import InputSpec
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear from paddle.fluid.dygraph import Linear
from paddle.fluid.dygraph import declarative, ProgramTranslator from paddle.fluid.dygraph import declarative, ProgramTranslator
from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME, VARIABLE_FILENAME
BATCH_SIZE = 32 BATCH_SIZE = 32
BATCH_NUM = 10 BATCH_NUM = 10
...@@ -153,6 +153,24 @@ class LinearNetReturnHidden(fluid.dygraph.Layer): ...@@ -153,6 +153,24 @@ class LinearNetReturnHidden(fluid.dygraph.Layer):
return y, loss return y, loss
class EmptyLayer(paddle.nn.Layer):
    """Layer that defines no sublayers or parameters of its own.

    Its forward pass hands the input straight back to the caller.
    """

    def __init__(self):
        super(EmptyLayer, self).__init__()

    # Identity forward: the argument is returned untouched.
    @paddle.jit.to_static
    def forward(self, x):
        return x
class NoParamLayer(paddle.nn.Layer):
    """Layer that defines no parameters of its own but still computes a result.

    Its forward pass returns ``x + y`` for the two inputs.
    """

    def __init__(self):
        super(NoParamLayer, self).__init__()

    # The only operation performed is the addition of the two inputs.
    @paddle.jit.to_static
    def forward(self, x, y):
        return x + y
def train(layer, input_size=784, label_size=1): def train(layer, input_size=784, label_size=1):
# create optimizer # create optimizer
sgd = fluid.optimizer.SGDOptimizer( sgd = fluid.optimizer.SGDOptimizer(
...@@ -273,6 +291,15 @@ class TestJitSaveLoad(unittest.TestCase): ...@@ -273,6 +291,15 @@ class TestJitSaveLoad(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
model_dict, _ = fluid.dygraph.load_dygraph(model_path) model_dict, _ = fluid.dygraph.load_dygraph(model_path)
def test_jit_load_model_incomplete(self):
    """Loading a saved model whose variables file was deleted raises ValueError."""
    model_path = "model.test_jit_save_load.remove_variables"
    self.train_and_save_model(model_path=model_path)

    # Drop the saved variables file so the model directory is incomplete.
    os.remove(os.path.join(model_path, VARIABLE_FILENAME))

    with self.assertRaises(ValueError):
        paddle.jit.load(model_path)
class TestSaveLoadWithInputSpec(unittest.TestCase): class TestSaveLoadWithInputSpec(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -695,5 +722,38 @@ class TestJitSaveMultiCases(unittest.TestCase): ...@@ -695,5 +722,38 @@ class TestJitSaveMultiCases(unittest.TestCase):
configs=configs) configs=configs)
class TestJitSaveLoadEmptyLayer(unittest.TestCase):
    """Round-trips an ``EmptyLayer`` through ``paddle.jit.save`` / ``paddle.jit.load``."""

    def setUp(self):
        self.model_path = "model.jit_save_load_empty_layer"
        # enable dygraph mode
        paddle.disable_static()

    def test_save_load_empty_layer(self):
        """The reloaded layer must produce the same output as the original one."""
        data = np.random.random(10).astype('float32')
        inputs = paddle.to_variable(data)

        net = EmptyLayer()
        expected = net(inputs)

        # Save after one forward call, then load the model back from the same path.
        paddle.jit.save(net, self.model_path)
        restored = paddle.jit.load(self.model_path)
        actual = restored(inputs)

        self.assertTrue(np.array_equal(expected, actual))
class TestJitSaveLoadNoParamLayer(unittest.TestCase):
    """Round-trips a ``NoParamLayer`` through ``paddle.jit.save`` / ``paddle.jit.load``."""

    def setUp(self):
        self.model_path = "model.jit_save_load_no_param_layer"
        # enable dygraph mode
        paddle.disable_static()

    def test_save_load_no_param_layer(self):
        """The reloaded layer must produce the same output as the original one."""
        lhs = paddle.to_variable(np.random.random(5).astype('float32'))
        rhs = paddle.to_variable(np.random.random(5).astype('float32'))

        net = NoParamLayer()
        expected = net(lhs, rhs)

        # Save after one forward call, then load the model back from the same path.
        paddle.jit.save(net, self.model_path)
        restored = paddle.jit.load(self.model_path)
        actual = restored(lhs, rhs)

        self.assertTrue(np.array_equal(expected, actual))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册