From 08033c86349695c838bea0fd18054d2ee5328d6e Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Wed, 12 Feb 2020 10:02:01 -0600 Subject: [PATCH] fix traced layer with non persistable vars, test=develop (#22552) --- .../imperative/jit/program_desc_tracer.cc | 54 +++++++++--- .../imperative/jit/program_desc_tracer.h | 11 ++- python/paddle/fluid/dygraph/jit.py | 8 +- ...imperative_trace_non_persistable_inputs.py | 85 +++++++++++++++++++ 4 files changed, 141 insertions(+), 17 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_imperative_trace_non_persistable_inputs.py diff --git a/paddle/fluid/imperative/jit/program_desc_tracer.cc b/paddle/fluid/imperative/jit/program_desc_tracer.cc index 2e92facb06..be93a787d4 100644 --- a/paddle/fluid/imperative/jit/program_desc_tracer.cc +++ b/paddle/fluid/imperative/jit/program_desc_tracer.cc @@ -25,6 +25,7 @@ namespace jit { class UniqueBlockVarGenerator { public: UniqueBlockVarGenerator(const VarDescMetaMap &all_vars, + const VarBaseSet &non_exist_input_vars, framework::BlockDesc *block); std::string NameOf(const std::weak_ptr &var, @@ -33,7 +34,8 @@ class UniqueBlockVarGenerator { private: void InsertNewVarInBlock(const std::weak_ptr &var, const framework::VarDesc &ref_desc, - const std::string &name); + const std::string &name, + bool force_persistable = false); private: const VarDescMetaMap &all_vars_; @@ -46,13 +48,18 @@ class UniqueBlockVarGenerator { std::unordered_set existing_names_; }; -UniqueBlockVarGenerator::UniqueBlockVarGenerator(const VarDescMetaMap &all_vars, - framework::BlockDesc *block) +UniqueBlockVarGenerator::UniqueBlockVarGenerator( + const VarDescMetaMap &all_vars, const VarBaseSet &non_exist_input_vars, + framework::BlockDesc *block) : all_vars_(all_vars), block_(block) { for (auto &var_pair : all_vars_) { auto *var_desc = var_pair.second.get(); if (var_desc->Persistable()) { InsertNewVarInBlock(var_pair.first, *var_desc, var_desc->Name()); + } else if (non_exist_input_vars.count(var_pair.first.lock()) > 0) { + VLOG(10) << "Mark " << var_desc->Name() << " as persistable"; + InsertNewVarInBlock(var_pair.first, *var_desc, var_desc->Name(), + /*force_persistable=*/true); } } } @@ -90,12 +97,15 @@ std::string UniqueBlockVarGenerator::NameOf(const std::weak_ptr &var, void UniqueBlockVarGenerator::InsertNewVarInBlock( const std::weak_ptr &var, const framework::VarDesc &var_desc, - const std::string &name) { + const std::string &name, bool force_persistable) { var_to_name_[var] = name; existing_names_.insert(name); auto *new_var_desc = block_->Var(name); *new_var_desc = var_desc; new_var_desc->SetName(name); + if (force_persistable) { + new_var_desc->SetPersistable(true); + } } void ProgramDescTracer::InsertOp(const std::string &type, @@ -106,13 +116,13 @@ void ProgramDescTracer::InsertOp(const std::string &type, auto &new_op = ops_.back(); for (auto &pair : new_op->Inputs()) { for (auto &var : pair.second) { - InsertVarIfNotExist(var.lock()); + InsertVarIfNotExist(var.lock(), true); } } for (auto &pair : new_op->Outputs()) { for (auto &var : pair.second) { - InsertVarIfNotExist(var.lock()); + InsertVarIfNotExist(var.lock(), false); } } } @@ -125,7 +135,12 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc( std::unique_ptr prog(new framework::ProgramDesc()); auto *block = prog->MutableBlock(0); - UniqueBlockVarGenerator generator(vars_, block); + auto non_exist_vars_copy = non_exist_input_vars_; + for (auto &feed_var : feed_vars) { + non_exist_vars_copy.erase(feed_var); + } + + UniqueBlockVarGenerator generator(vars_, non_exist_vars_copy, block); std::vector feed_var_names; for (auto &feed_var : feed_vars) { @@ -164,21 +179,37 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc( } prog->Flush(); + + std::vector> persistable_vars( + non_exist_vars_copy.begin(), non_exist_vars_copy.end()); + for (auto &pair : vars_) { + if (pair.second->Persistable()) { + auto var = pair.first.lock(); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound("Persistable var %s does not exist", + pair.second->Name())); + persistable_vars.emplace_back(var); + } + } return std::make_tuple(std::move(prog), std::move(feed_var_names), - std::move(fetch_var_names)); + std::move(fetch_var_names), + std::move(persistable_vars)); } void ProgramDescTracer::InsertVarIfNotExist( - const std::shared_ptr &new_var) { + const std::shared_ptr &new_var, bool is_input) { PADDLE_ENFORCE_NOT_NULL(new_var); if (vars_.count(new_var) != 0) return; auto new_var_desc = new framework::VarDesc(""); vars_[new_var].reset(new_var_desc); - if (new_var->Persistable()) { + if (new_var->Persistable() || is_input) { new_var_desc->SetName(new_var->Name()); - new_var_desc->SetPersistable(true); + new_var_desc->SetPersistable(new_var->Persistable()); + if (!new_var->Persistable()) { + non_exist_input_vars_.insert(new_var); + } } else { new_var_desc->SetPersistable(false); } @@ -204,6 +235,7 @@ void ProgramDescTracer::InsertVarIfNotExist( void ProgramDescTracer::Reset() { ops_.clear(); vars_.clear(); + non_exist_input_vars_.clear(); } } // namespace jit diff --git a/paddle/fluid/imperative/jit/program_desc_tracer.h b/paddle/fluid/imperative/jit/program_desc_tracer.h index 4ef29d0f44..d07acec223 100644 --- a/paddle/fluid/imperative/jit/program_desc_tracer.h +++ b/paddle/fluid/imperative/jit/program_desc_tracer.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -34,10 +35,14 @@ using VarDescMetaMap = std::map, std::unique_ptr, std::owner_less>>; +using VarBaseSet = std::set, + std::owner_less>>; + using TracedProgramTuple = std::tuple /*program*/, std::vector /*feed_var_names*/, - std::vector /*fetch_var_names*/>; + std::vector /*fetch_var_names*/, + std::vector> /*persistable_vars*/>; class ProgramDescTracer { DISABLE_COPY_AND_ASSIGN(ProgramDescTracer); @@ -58,11 +63,13 @@ class ProgramDescTracer { void Reset(); private: - void InsertVarIfNotExist(const std::shared_ptr &new_var); + void InsertVarIfNotExist(const std::shared_ptr &new_var, + bool is_input); private: std::vector> ops_; VarDescMetaMap vars_; + VarBaseSet non_exist_input_vars_; }; } // namespace jit diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index e067e70443..c2de0f17d1 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -93,14 +93,14 @@ def _trace(layer, outputs = original_outputs out_vars = [var for var in outputs] - program_desc, feed_names, fetch_names = tracer.create_program_desc( + program_desc, feed_names, fetch_names, parameters = tracer.create_program_desc( var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix) tracer.reset() with _dygraph_guard(None): program = create_program_from_desc(program_desc) - return original_outputs, program, feed_names, fetch_names + return original_outputs, program, feed_names, fetch_names, parameters class TracedLayer(object): @@ -199,8 +199,8 @@ class TracedLayer(object): # save the static graph model for inference static_layer.save_inference_model(dirname='./saved_infer_model') """ - outs, prog, feed, fetch = _trace(layer, inputs) - traced = TracedLayer(prog, layer.parameters(), feed, fetch) + outs, prog, feed, fetch, parameters = _trace(layer, inputs) + traced = TracedLayer(prog, parameters, feed, fetch) return outs, traced def set_strategy(self, build_strategy=None, exec_strategy=None): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_trace_non_persistable_inputs.py b/python/paddle/fluid/tests/unittests/test_imperative_trace_non_persistable_inputs.py new file mode 100644 index 0000000000..2a74d29e1e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative_trace_non_persistable_inputs.py @@ -0,0 +1,85 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle.fluid as fluid +import numpy as np +import six +import os + + +class SimpleFCLayer(fluid.dygraph.Layer): + def __init__(self, feature_size, batch_size, fc_size): + super(SimpleFCLayer, self).__init__() + self._linear = fluid.dygraph.Linear(feature_size, fc_size) + self._offset = fluid.dygraph.to_variable( + np.random.random((batch_size, fc_size)).astype('float32')) + + def forward(self, x): + fc = self._linear(x) + return fc + self._offset + + +class TestTracedLayerRecordNonPersistableInput(unittest.TestCase): + def test_main(self): + traced_layer = None + with fluid.dygraph.guard(): + feature_size = 3 + batch_size = 4 + fc_size = 2 + layer = SimpleFCLayer(feature_size, batch_size, fc_size) + optimizer = fluid.optimizer.SGD(learning_rate=1e-3, + parameter_list=layer.parameters()) + + expected_persistable_vars = set([ + layer._linear.weight.name, layer._linear.bias.name, + layer._offset.name + ]) + + for _ in six.moves.range(10): + in_x = fluid.dygraph.to_variable( + np.random.random((batch_size, feature_size)).astype( + 'float32')) + if traced_layer is None: + dygraph_out, traced_layer = fluid.dygraph.TracedLayer.trace( + layer, [in_x]) + else: + dygraph_out = layer(in_x) + dygraph_out_numpy = dygraph_out.numpy() + static_out = traced_layer([in_x])[0] + self.assertTrue(np.array_equal(dygraph_out_numpy, static_out)) + + loss = fluid.layers.reduce_mean(dygraph_out) + loss.backward() + + optimizer.minimize(loss) + + del layer + + program = traced_layer.program + actual_persistable_vars = set() + for var in program.list_vars(): + if var.persistable: + actual_persistable_vars.add(var.name) + + self.assertEqual(actual_persistable_vars, expected_persistable_vars) + + dirname = './traced_layer_test_non_persistable_vars' + traced_layer.save_inference_model(dirname=dirname) + filenames = set([f for f in os.listdir(dirname) if f != '__model__']) + self.assertEqual(filenames, expected_persistable_vars) + + +if __name__ == '__main__': + unittest.main() -- GitLab