未验证 提交 08033c86 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix traced layer with non persistable vars, test=develop (#22552)

上级 31b54646
......@@ -25,6 +25,7 @@ namespace jit {
class UniqueBlockVarGenerator {
public:
UniqueBlockVarGenerator(const VarDescMetaMap &all_vars,
const VarBaseSet &non_exist_input_vars,
framework::BlockDesc *block);
std::string NameOf(const std::weak_ptr<VarBase> &var,
......@@ -33,7 +34,8 @@ class UniqueBlockVarGenerator {
private:
void InsertNewVarInBlock(const std::weak_ptr<VarBase> &var,
const framework::VarDesc &ref_desc,
const std::string &name);
const std::string &name,
bool force_persistable = false);
private:
const VarDescMetaMap &all_vars_;
......@@ -46,13 +48,18 @@ class UniqueBlockVarGenerator {
std::unordered_set<std::string> existing_names_;
};
UniqueBlockVarGenerator::UniqueBlockVarGenerator(const VarDescMetaMap &all_vars,
framework::BlockDesc *block)
UniqueBlockVarGenerator::UniqueBlockVarGenerator(
const VarDescMetaMap &all_vars, const VarBaseSet &non_exist_input_vars,
framework::BlockDesc *block)
: all_vars_(all_vars), block_(block) {
for (auto &var_pair : all_vars_) {
auto *var_desc = var_pair.second.get();
if (var_desc->Persistable()) {
InsertNewVarInBlock(var_pair.first, *var_desc, var_desc->Name());
} else if (non_exist_input_vars.count(var_pair.first.lock()) > 0) {
VLOG(10) << "Mark " << var_desc->Name() << " as persistable";
InsertNewVarInBlock(var_pair.first, *var_desc, var_desc->Name(),
/*force_persistable=*/true);
}
}
}
......@@ -90,12 +97,15 @@ std::string UniqueBlockVarGenerator::NameOf(const std::weak_ptr<VarBase> &var,
void UniqueBlockVarGenerator::InsertNewVarInBlock(
const std::weak_ptr<VarBase> &var, const framework::VarDesc &var_desc,
const std::string &name) {
const std::string &name, bool force_persistable) {
var_to_name_[var] = name;
existing_names_.insert(name);
auto *new_var_desc = block_->Var(name);
*new_var_desc = var_desc;
new_var_desc->SetName(name);
if (force_persistable) {
new_var_desc->SetPersistable(true);
}
}
void ProgramDescTracer::InsertOp(const std::string &type,
......@@ -106,13 +116,13 @@ void ProgramDescTracer::InsertOp(const std::string &type,
auto &new_op = ops_.back();
for (auto &pair : new_op->Inputs()) {
for (auto &var : pair.second) {
InsertVarIfNotExist(var.lock());
InsertVarIfNotExist(var.lock(), true);
}
}
for (auto &pair : new_op->Outputs()) {
for (auto &var : pair.second) {
InsertVarIfNotExist(var.lock());
InsertVarIfNotExist(var.lock(), false);
}
}
}
......@@ -125,7 +135,12 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
std::unique_ptr<framework::ProgramDesc> prog(new framework::ProgramDesc());
auto *block = prog->MutableBlock(0);
UniqueBlockVarGenerator generator(vars_, block);
auto non_exist_vars_copy = non_exist_input_vars_;
for (auto &feed_var : feed_vars) {
non_exist_vars_copy.erase(feed_var);
}
UniqueBlockVarGenerator generator(vars_, non_exist_vars_copy, block);
std::vector<std::string> feed_var_names;
for (auto &feed_var : feed_vars) {
......@@ -164,21 +179,37 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
}
prog->Flush();
std::vector<std::shared_ptr<VarBase>> persistable_vars(
non_exist_vars_copy.begin(), non_exist_vars_copy.end());
for (auto &pair : vars_) {
if (pair.second->Persistable()) {
auto var = pair.first.lock();
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Persistable var %s does not exist",
pair.second->Name()));
persistable_vars.emplace_back(var);
}
}
return std::make_tuple(std::move(prog), std::move(feed_var_names),
std::move(fetch_var_names));
std::move(fetch_var_names),
std::move(persistable_vars));
}
void ProgramDescTracer::InsertVarIfNotExist(
const std::shared_ptr<VarBase> &new_var) {
const std::shared_ptr<VarBase> &new_var, bool is_input) {
PADDLE_ENFORCE_NOT_NULL(new_var);
if (vars_.count(new_var) != 0) return;
auto new_var_desc = new framework::VarDesc("");
vars_[new_var].reset(new_var_desc);
if (new_var->Persistable()) {
if (new_var->Persistable() || is_input) {
new_var_desc->SetName(new_var->Name());
new_var_desc->SetPersistable(true);
new_var_desc->SetPersistable(new_var->Persistable());
if (!new_var->Persistable()) {
non_exist_input_vars_.insert(new_var);
}
} else {
new_var_desc->SetPersistable(false);
}
......@@ -204,6 +235,7 @@ void ProgramDescTracer::InsertVarIfNotExist(
void ProgramDescTracer::Reset() {
ops_.clear();
vars_.clear();
non_exist_input_vars_.clear();
}
} // namespace jit
......
......@@ -16,6 +16,7 @@
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <utility>
......@@ -34,10 +35,14 @@ using VarDescMetaMap =
std::map<std::weak_ptr<VarBase>, std::unique_ptr<framework::VarDesc>,
std::owner_less<std::weak_ptr<VarBase>>>;
using VarBaseSet = std::set<std::shared_ptr<VarBase>,
std::owner_less<std::shared_ptr<VarBase>>>;
using TracedProgramTuple =
std::tuple<std::unique_ptr<framework::ProgramDesc> /*program*/,
std::vector<std::string> /*feed_var_names*/,
std::vector<std::string> /*fetch_var_names*/>;
std::vector<std::string> /*fetch_var_names*/,
std::vector<std::shared_ptr<VarBase>> /*persistable_vars*/>;
class ProgramDescTracer {
DISABLE_COPY_AND_ASSIGN(ProgramDescTracer);
......@@ -58,11 +63,13 @@ class ProgramDescTracer {
void Reset();
private:
void InsertVarIfNotExist(const std::shared_ptr<VarBase> &new_var);
void InsertVarIfNotExist(const std::shared_ptr<VarBase> &new_var,
bool is_input);
private:
std::vector<std::unique_ptr<OpDescMeta>> ops_;
VarDescMetaMap vars_;
VarBaseSet non_exist_input_vars_;
};
} // namespace jit
......
......@@ -93,14 +93,14 @@ def _trace(layer,
outputs = original_outputs
out_vars = [var for var in outputs]
program_desc, feed_names, fetch_names = tracer.create_program_desc(
program_desc, feed_names, fetch_names, parameters = tracer.create_program_desc(
var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix)
tracer.reset()
with _dygraph_guard(None):
program = create_program_from_desc(program_desc)
return original_outputs, program, feed_names, fetch_names
return original_outputs, program, feed_names, fetch_names, parameters
class TracedLayer(object):
......@@ -199,8 +199,8 @@ class TracedLayer(object):
# save the static graph model for inference
static_layer.save_inference_model(dirname='./saved_infer_model')
"""
outs, prog, feed, fetch = _trace(layer, inputs)
traced = TracedLayer(prog, layer.parameters(), feed, fetch)
outs, prog, feed, fetch, parameters = _trace(layer, inputs)
traced = TracedLayer(prog, parameters, feed, fetch)
return outs, traced
def set_strategy(self, build_strategy=None, exec_strategy=None):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import numpy as np
import six
import os
class SimpleFCLayer(fluid.dygraph.Layer):
def __init__(self, feature_size, batch_size, fc_size):
super(SimpleFCLayer, self).__init__()
self._linear = fluid.dygraph.Linear(feature_size, fc_size)
self._offset = fluid.dygraph.to_variable(
np.random.random((batch_size, fc_size)).astype('float32'))
def forward(self, x):
fc = self._linear(x)
return fc + self._offset
class TestTracedLayerRecordNonPersistableInput(unittest.TestCase):
def test_main(self):
traced_layer = None
with fluid.dygraph.guard():
feature_size = 3
batch_size = 4
fc_size = 2
layer = SimpleFCLayer(feature_size, batch_size, fc_size)
optimizer = fluid.optimizer.SGD(learning_rate=1e-3,
parameter_list=layer.parameters())
expected_persistable_vars = set([
layer._linear.weight.name, layer._linear.bias.name,
layer._offset.name
])
for _ in six.moves.range(10):
in_x = fluid.dygraph.to_variable(
np.random.random((batch_size, feature_size)).astype(
'float32'))
if traced_layer is None:
dygraph_out, traced_layer = fluid.dygraph.TracedLayer.trace(
layer, [in_x])
else:
dygraph_out = layer(in_x)
dygraph_out_numpy = dygraph_out.numpy()
static_out = traced_layer([in_x])[0]
self.assertTrue(np.array_equal(dygraph_out_numpy, static_out))
loss = fluid.layers.reduce_mean(dygraph_out)
loss.backward()
optimizer.minimize(loss)
del layer
program = traced_layer.program
actual_persistable_vars = set()
for var in program.list_vars():
if var.persistable:
actual_persistable_vars.add(var.name)
self.assertEqual(actual_persistable_vars, expected_persistable_vars)
dirname = './traced_layer_test_non_persistable_vars'
traced_layer.save_inference_model(dirname=dirname)
filenames = set([f for f in os.listdir(dirname) if f != '__model__'])
self.assertEqual(filenames, expected_persistable_vars)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册