未验证 提交 4e53be82 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix traced layer with non persistable vars, test=release/1.7 (#22554)

上级 167de0fa
...@@ -25,6 +25,7 @@ namespace jit { ...@@ -25,6 +25,7 @@ namespace jit {
class UniqueBlockVarGenerator { class UniqueBlockVarGenerator {
public: public:
UniqueBlockVarGenerator(const VarDescMetaMap &all_vars, UniqueBlockVarGenerator(const VarDescMetaMap &all_vars,
const VarBaseSet &non_exist_input_vars,
framework::BlockDesc *block); framework::BlockDesc *block);
std::string NameOf(const std::weak_ptr<VarBase> &var, std::string NameOf(const std::weak_ptr<VarBase> &var,
...@@ -33,7 +34,8 @@ class UniqueBlockVarGenerator { ...@@ -33,7 +34,8 @@ class UniqueBlockVarGenerator {
private: private:
void InsertNewVarInBlock(const std::weak_ptr<VarBase> &var, void InsertNewVarInBlock(const std::weak_ptr<VarBase> &var,
const framework::VarDesc &ref_desc, const framework::VarDesc &ref_desc,
const std::string &name); const std::string &name,
bool force_persistable = false);
private: private:
const VarDescMetaMap &all_vars_; const VarDescMetaMap &all_vars_;
...@@ -46,13 +48,18 @@ class UniqueBlockVarGenerator { ...@@ -46,13 +48,18 @@ class UniqueBlockVarGenerator {
std::unordered_set<std::string> existing_names_; std::unordered_set<std::string> existing_names_;
}; };
UniqueBlockVarGenerator::UniqueBlockVarGenerator(const VarDescMetaMap &all_vars, UniqueBlockVarGenerator::UniqueBlockVarGenerator(
const VarDescMetaMap &all_vars, const VarBaseSet &non_exist_input_vars,
framework::BlockDesc *block) framework::BlockDesc *block)
: all_vars_(all_vars), block_(block) { : all_vars_(all_vars), block_(block) {
for (auto &var_pair : all_vars_) { for (auto &var_pair : all_vars_) {
auto *var_desc = var_pair.second.get(); auto *var_desc = var_pair.second.get();
if (var_desc->Persistable()) { if (var_desc->Persistable()) {
InsertNewVarInBlock(var_pair.first, *var_desc, var_desc->Name()); InsertNewVarInBlock(var_pair.first, *var_desc, var_desc->Name());
} else if (non_exist_input_vars.count(var_pair.first.lock()) > 0) {
VLOG(10) << "Mark " << var_desc->Name() << " as persistable";
InsertNewVarInBlock(var_pair.first, *var_desc, var_desc->Name(),
/*force_persistable=*/true);
} }
} }
} }
...@@ -90,12 +97,15 @@ std::string UniqueBlockVarGenerator::NameOf(const std::weak_ptr<VarBase> &var, ...@@ -90,12 +97,15 @@ std::string UniqueBlockVarGenerator::NameOf(const std::weak_ptr<VarBase> &var,
void UniqueBlockVarGenerator::InsertNewVarInBlock( void UniqueBlockVarGenerator::InsertNewVarInBlock(
const std::weak_ptr<VarBase> &var, const framework::VarDesc &var_desc, const std::weak_ptr<VarBase> &var, const framework::VarDesc &var_desc,
const std::string &name) { const std::string &name, bool force_persistable) {
var_to_name_[var] = name; var_to_name_[var] = name;
existing_names_.insert(name); existing_names_.insert(name);
auto *new_var_desc = block_->Var(name); auto *new_var_desc = block_->Var(name);
*new_var_desc = var_desc; *new_var_desc = var_desc;
new_var_desc->SetName(name); new_var_desc->SetName(name);
if (force_persistable) {
new_var_desc->SetPersistable(true);
}
} }
void ProgramDescTracer::InsertOp(const std::string &type, void ProgramDescTracer::InsertOp(const std::string &type,
...@@ -106,13 +116,13 @@ void ProgramDescTracer::InsertOp(const std::string &type, ...@@ -106,13 +116,13 @@ void ProgramDescTracer::InsertOp(const std::string &type,
auto &new_op = ops_.back(); auto &new_op = ops_.back();
for (auto &pair : new_op->Inputs()) { for (auto &pair : new_op->Inputs()) {
for (auto &var : pair.second) { for (auto &var : pair.second) {
InsertVarIfNotExist(var.lock()); InsertVarIfNotExist(var.lock(), true);
} }
} }
for (auto &pair : new_op->Outputs()) { for (auto &pair : new_op->Outputs()) {
for (auto &var : pair.second) { for (auto &var : pair.second) {
InsertVarIfNotExist(var.lock()); InsertVarIfNotExist(var.lock(), false);
} }
} }
} }
...@@ -125,7 +135,12 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc( ...@@ -125,7 +135,12 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
std::unique_ptr<framework::ProgramDesc> prog(new framework::ProgramDesc()); std::unique_ptr<framework::ProgramDesc> prog(new framework::ProgramDesc());
auto *block = prog->MutableBlock(0); auto *block = prog->MutableBlock(0);
UniqueBlockVarGenerator generator(vars_, block); auto non_exist_vars_copy = non_exist_input_vars_;
for (auto &feed_var : feed_vars) {
non_exist_vars_copy.erase(feed_var);
}
UniqueBlockVarGenerator generator(vars_, non_exist_vars_copy, block);
std::vector<std::string> feed_var_names; std::vector<std::string> feed_var_names;
for (auto &feed_var : feed_vars) { for (auto &feed_var : feed_vars) {
...@@ -164,21 +179,37 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc( ...@@ -164,21 +179,37 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
} }
prog->Flush(); prog->Flush();
std::vector<std::shared_ptr<VarBase>> persistable_vars(
non_exist_vars_copy.begin(), non_exist_vars_copy.end());
for (auto &pair : vars_) {
if (pair.second->Persistable()) {
auto var = pair.first.lock();
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Persistable var %s does not exist",
pair.second->Name()));
persistable_vars.emplace_back(var);
}
}
return std::make_tuple(std::move(prog), std::move(feed_var_names), return std::make_tuple(std::move(prog), std::move(feed_var_names),
std::move(fetch_var_names)); std::move(fetch_var_names),
std::move(persistable_vars));
} }
void ProgramDescTracer::InsertVarIfNotExist( void ProgramDescTracer::InsertVarIfNotExist(
const std::shared_ptr<VarBase> &new_var) { const std::shared_ptr<VarBase> &new_var, bool is_input) {
PADDLE_ENFORCE_NOT_NULL(new_var); PADDLE_ENFORCE_NOT_NULL(new_var);
if (vars_.count(new_var) != 0) return; if (vars_.count(new_var) != 0) return;
auto new_var_desc = new framework::VarDesc(""); auto new_var_desc = new framework::VarDesc("");
vars_[new_var].reset(new_var_desc); vars_[new_var].reset(new_var_desc);
if (new_var->Persistable()) { if (new_var->Persistable() || is_input) {
new_var_desc->SetName(new_var->Name()); new_var_desc->SetName(new_var->Name());
new_var_desc->SetPersistable(true); new_var_desc->SetPersistable(new_var->Persistable());
if (!new_var->Persistable()) {
non_exist_input_vars_.insert(new_var);
}
} else { } else {
new_var_desc->SetPersistable(false); new_var_desc->SetPersistable(false);
} }
...@@ -204,6 +235,7 @@ void ProgramDescTracer::InsertVarIfNotExist( ...@@ -204,6 +235,7 @@ void ProgramDescTracer::InsertVarIfNotExist(
void ProgramDescTracer::Reset() { void ProgramDescTracer::Reset() {
ops_.clear(); ops_.clear();
vars_.clear(); vars_.clear();
non_exist_input_vars_.clear();
} }
} // namespace jit } // namespace jit
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <map> #include <map>
#include <memory> #include <memory>
#include <set>
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
...@@ -34,10 +35,14 @@ using VarDescMetaMap = ...@@ -34,10 +35,14 @@ using VarDescMetaMap =
std::map<std::weak_ptr<VarBase>, std::unique_ptr<framework::VarDesc>, std::map<std::weak_ptr<VarBase>, std::unique_ptr<framework::VarDesc>,
std::owner_less<std::weak_ptr<VarBase>>>; std::owner_less<std::weak_ptr<VarBase>>>;
using VarBaseSet = std::set<std::shared_ptr<VarBase>,
std::owner_less<std::shared_ptr<VarBase>>>;
using TracedProgramTuple = using TracedProgramTuple =
std::tuple<std::unique_ptr<framework::ProgramDesc> /*program*/, std::tuple<std::unique_ptr<framework::ProgramDesc> /*program*/,
std::vector<std::string> /*feed_var_names*/, std::vector<std::string> /*feed_var_names*/,
std::vector<std::string> /*fetch_var_names*/>; std::vector<std::string> /*fetch_var_names*/,
std::vector<std::shared_ptr<VarBase>> /*persistable_vars*/>;
class ProgramDescTracer { class ProgramDescTracer {
DISABLE_COPY_AND_ASSIGN(ProgramDescTracer); DISABLE_COPY_AND_ASSIGN(ProgramDescTracer);
...@@ -58,11 +63,13 @@ class ProgramDescTracer { ...@@ -58,11 +63,13 @@ class ProgramDescTracer {
void Reset(); void Reset();
private: private:
void InsertVarIfNotExist(const std::shared_ptr<VarBase> &new_var); void InsertVarIfNotExist(const std::shared_ptr<VarBase> &new_var,
bool is_input);
private: private:
std::vector<std::unique_ptr<OpDescMeta>> ops_; std::vector<std::unique_ptr<OpDescMeta>> ops_;
VarDescMetaMap vars_; VarDescMetaMap vars_;
VarBaseSet non_exist_input_vars_;
}; };
} // namespace jit } // namespace jit
......
...@@ -68,14 +68,14 @@ def _trace(layer, ...@@ -68,14 +68,14 @@ def _trace(layer,
outputs = original_outputs outputs = original_outputs
out_vars = [var for var in outputs] out_vars = [var for var in outputs]
program_desc, feed_names, fetch_names = tracer.create_program_desc( program_desc, feed_names, fetch_names, parameters = tracer.create_program_desc(
var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix) var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix)
tracer.reset() tracer.reset()
with _dygraph_guard(None): with _dygraph_guard(None):
program = create_program_from_desc(program_desc) program = create_program_from_desc(program_desc)
return original_outputs, program, feed_names, fetch_names return original_outputs, program, feed_names, fetch_names, parameters
class TracedLayer(object): class TracedLayer(object):
...@@ -174,8 +174,8 @@ class TracedLayer(object): ...@@ -174,8 +174,8 @@ class TracedLayer(object):
# save the static graph model for inference # save the static graph model for inference
static_layer.save_inference_model(dirname='./saved_infer_model') static_layer.save_inference_model(dirname='./saved_infer_model')
""" """
outs, prog, feed, fetch = _trace(layer, inputs) outs, prog, feed, fetch, parameters = _trace(layer, inputs)
traced = TracedLayer(prog, layer.parameters(), feed, fetch) traced = TracedLayer(prog, parameters, feed, fetch)
return outs, traced return outs, traced
def set_strategy(self, build_strategy=None, exec_strategy=None): def set_strategy(self, build_strategy=None, exec_strategy=None):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import numpy as np
import six
import os
class SimpleFCLayer(fluid.dygraph.Layer):
def __init__(self, feature_size, batch_size, fc_size):
super(SimpleFCLayer, self).__init__()
self._linear = fluid.dygraph.Linear(feature_size, fc_size)
self._offset = fluid.dygraph.to_variable(
np.random.random((batch_size, fc_size)).astype('float32'))
def forward(self, x):
fc = self._linear(x)
return fc + self._offset
class TestTracedLayerRecordNonPersistableInput(unittest.TestCase):
def test_main(self):
traced_layer = None
with fluid.dygraph.guard():
feature_size = 3
batch_size = 4
fc_size = 2
layer = SimpleFCLayer(feature_size, batch_size, fc_size)
optimizer = fluid.optimizer.SGD(learning_rate=1e-3,
parameter_list=layer.parameters())
expected_persistable_vars = set([
layer._linear.weight.name, layer._linear.bias.name,
layer._offset.name
])
for _ in six.moves.range(10):
in_x = fluid.dygraph.to_variable(
np.random.random((batch_size, feature_size)).astype(
'float32'))
if traced_layer is None:
dygraph_out, traced_layer = fluid.dygraph.TracedLayer.trace(
layer, [in_x])
else:
dygraph_out = layer(in_x)
dygraph_out_numpy = dygraph_out.numpy()
static_out = traced_layer([in_x])[0]
self.assertTrue(np.array_equal(dygraph_out_numpy, static_out))
loss = fluid.layers.reduce_mean(dygraph_out)
loss.backward()
optimizer.minimize(loss)
del layer
program = traced_layer.program
actual_persistable_vars = set()
for var in program.list_vars():
if var.persistable:
actual_persistable_vars.add(var.name)
self.assertEqual(actual_persistable_vars, expected_persistable_vars)
dirname = './traced_layer_test_non_persistable_vars'
traced_layer.save_inference_model(dirname=dirname)
filenames = set([f for f in os.listdir(dirname) if f != '__model__'])
self.assertEqual(filenames, expected_persistable_vars)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册