未验证 提交 8ff6b289 编写于 作者: Z Zeng Jinle 提交者: GitHub

[Dygraph to static graph]JIT/Trace (#20775)

* jit/trace 1st version, test=develop

* add more unittests, test=develop
上级 6e6eab07
......@@ -3,7 +3,9 @@ cc_library(imperative_flag SRCS flags.cc DEPS gflags)
cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform)
cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry)
cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows var_type_traits layer)
cc_library(tracer SRCS tracer.cc DEPS layer engine)
add_subdirectory(jit)
cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer)
cc_library(engine SRCS engine.cc DEPS layer gradient_accumulator)
cc_library(imperative_profiler SRCS profiler.cc)
cc_library(nccl_context SRCS nccl_context.cc DEPS device_context)
......
cc_library(op_desc_meta SRCS op_desc_meta.cc DEPS proto_desc layer)
cc_library(program_desc_tracer SRCS program_desc_tracer.cc DEPS op_desc_meta)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/imperative/jit/op_desc_meta.h"
#include "paddle/fluid/framework/op_info.h"
namespace paddle {
namespace imperative {
namespace jit {
OpDescMeta::OpDescMeta(const std::string &type, const NameVarBaseMap &inputs,
const NameVarBaseMap &outputs,
const framework::AttributeMap &attrs)
: type_(type), attrs_(attrs) {
auto *proto = framework::OpInfoMap::Instance().GetNullable(type_);
if (proto && proto->Checker()) {
proto->Checker()->Check(&attrs_);
}
for (auto &pair : inputs) {
inputs_[pair.first].assign(pair.second.begin(), pair.second.end());
}
for (auto &pair : outputs) {
outputs_[pair.first].assign(pair.second.begin(), pair.second.end());
}
}
const std::string &OpDescMeta::Type() const { return type_; }
const WeakNameVarBaseMap &OpDescMeta::Inputs() const { return inputs_; }
const WeakNameVarBaseMap &OpDescMeta::Outputs() const { return outputs_; }
const framework::AttributeMap &OpDescMeta::Attrs() const { return attrs_; }
} // namespace jit
} // namespace imperative
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/type_defs.h"
namespace paddle {
namespace imperative {
namespace jit {
class OpDescMeta {
public:
OpDescMeta(const std::string &type, const NameVarBaseMap &inputs,
const NameVarBaseMap &outputs,
const framework::AttributeMap &attrs);
const std::string &Type() const;
const WeakNameVarBaseMap &Inputs() const;
const WeakNameVarBaseMap &Outputs() const;
const framework::AttributeMap &Attrs() const;
private:
std::string type_;
WeakNameVarBaseMap inputs_;
WeakNameVarBaseMap outputs_;
framework::AttributeMap attrs_;
};
} // namespace jit
} // namespace imperative
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/imperative/jit/program_desc_tracer.h"
#include <unordered_map>
#include <unordered_set>
#include <utility>
namespace paddle {
namespace imperative {
namespace jit {
void ProgramDescTracer::SetNamePrefix(const std::string &name_prefix) {
name_prefix_ = name_prefix;
}
void ProgramDescTracer::SetFeedVars(
const std::vector<std::shared_ptr<VarBase>> &feed_vars,
std::vector<std::string> feed_names) {
feed_vars_.clear();
if (feed_names.empty()) {
feed_names.reserve(feed_vars.size());
for (auto &var : feed_vars) {
feed_names.emplace_back(var->Name());
}
}
PADDLE_ENFORCE_EQ(feed_names.size(), feed_vars.size(),
"The feeded variable names number must be equal to the "
"feeded variable number");
for (size_t i = 0; i < feed_names.size(); ++i) {
feed_vars_[feed_vars[i]] = feed_names[i];
}
}
void ProgramDescTracer::SetFetchVars(
const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
std::vector<std::string> fetch_names) {
fetch_vars_.clear();
if (fetch_names.empty()) {
fetch_names.reserve(fetch_vars.size());
for (auto &var : fetch_vars) {
fetch_names.emplace_back(var->Name());
}
}
PADDLE_ENFORCE_EQ(fetch_names.size(), fetch_vars.size(),
"The fetched variable names number must be equal to the "
"fetched variable number");
for (size_t i = 0; i < fetch_names.size(); ++i) {
fetch_vars_[fetch_vars[i]] = fetch_names[i];
}
}
void ProgramDescTracer::InsertOp(const std::string &type,
const NameVarBaseMap &inputs,
const NameVarBaseMap &outputs,
const framework::AttributeMap &attrs) {
ops_.emplace_back(new OpDescMeta(type, inputs, outputs, attrs));
auto &new_op = ops_.back();
for (auto &pair : new_op->Inputs()) {
for (auto &var : pair.second) {
InsertVarIfNotExist(var.lock());
}
}
for (auto &pair : new_op->Outputs()) {
for (auto &var : pair.second) {
InsertVarIfNotExist(var.lock());
}
}
}
std::unique_ptr<framework::ProgramDesc> ProgramDescTracer::CreateProgramDesc()
const {
std::unique_ptr<framework::ProgramDesc> prog(new framework::ProgramDesc());
auto *block = prog->MutableBlock(0);
size_t var_num = vars_.size();
std::vector<framework::VarDesc *> var_descs(var_num, nullptr);
std::unordered_map<framework::VarDesc *, std::weak_ptr<VarBase>>
var_desc_to_var_base;
for (auto &pair : vars_) {
size_t var_id = pair.second.first;
PADDLE_ENFORCE_LT(var_id, var_num);
var_descs[var_id] = pair.second.second.get();
PADDLE_ENFORCE_NOT_NULL(var_descs[var_id]);
var_desc_to_var_base[var_descs[var_id]] = pair.first;
}
std::unordered_set<std::string> existing_var_names;
for (auto *var_desc : var_descs) {
if (var_desc->Persistable()) {
existing_var_names.insert(var_desc->Name());
}
}
for (auto &pair : feed_vars_) {
existing_var_names.insert(pair.second);
}
for (auto &pair : fetch_vars_) {
existing_var_names.insert(pair.second);
}
size_t counter = 0;
auto generate_unique_name = [&]() -> std::string {
do {
auto name = name_prefix_ + std::to_string(counter++);
if (existing_var_names.count(name) == 0) {
existing_var_names.insert(name);
return name;
}
} while (counter > 0);
PADDLE_THROW("Too many vars in the program");
};
std::map<std::weak_ptr<VarBase>, std::string,
std::owner_less<std::weak_ptr<VarBase>>>
var_to_name;
for (auto *var_desc : var_descs) {
auto var_name = var_desc->Name();
PADDLE_ENFORCE_EQ(var_desc_to_var_base.count(var_desc), 1);
std::weak_ptr<VarBase> var_base = var_desc_to_var_base.at(var_desc);
if (feed_vars_.count(var_base) > 0) {
var_name = feed_vars_.at(var_base);
} else if (fetch_vars_.count(var_base) > 0) {
var_name = fetch_vars_.at(var_base);
} else if (!var_desc->Persistable()) {
var_name = generate_unique_name();
}
auto *new_var_desc = block->Var(var_name);
*new_var_desc = *var_desc;
new_var_desc->SetName(std::move(var_name));
var_to_name[var_base] = new_var_desc->Name();
}
for (auto &op : ops_) {
auto *op_desc = block->AppendOp();
op_desc->SetType(op->Type());
op_desc->SetAttrMap(op->Attrs());
for (auto &pair : op->Inputs()) {
std::vector<std::string> names;
names.reserve(pair.second.size());
for (auto &var : pair.second) {
auto iter = var_to_name.find(var);
PADDLE_ENFORCE_EQ(iter != var_to_name.end(), true,
"Cannot find input variable");
names.emplace_back(iter->second);
}
op_desc->SetInput(pair.first, std::move(names));
}
for (auto &pair : op->Outputs()) {
std::vector<std::string> names;
names.reserve(pair.second.size());
for (auto &var : pair.second) {
auto iter = var_to_name.find(var);
PADDLE_ENFORCE_EQ(iter != var_to_name.end(), true,
"Cannot find output variable");
names.emplace_back(iter->second);
}
op_desc->SetOutput(pair.first, std::move(names));
}
}
prog->Flush();
return prog;
}
void ProgramDescTracer::InsertVarIfNotExist(
const std::shared_ptr<VarBase> &new_var) {
PADDLE_ENFORCE_NOT_NULL(new_var);
if (vars_.count(new_var) != 0) return;
size_t var_id = vars_.size();
auto new_var_desc = new framework::VarDesc("");
vars_[new_var] =
std::make_pair(var_id, std::unique_ptr<framework::VarDesc>(new_var_desc));
if (new_var->Persistable()) {
new_var_desc->SetName(new_var->Name());
new_var_desc->SetPersistable(true);
} else {
new_var_desc->SetPersistable(false);
}
const auto &inner_var = new_var->Var();
PADDLE_ENFORCE_EQ(inner_var.IsInitialized(), true);
if (inner_var.IsType<framework::LoDTensor>()) {
const auto &tensor = inner_var.Get<framework::LoDTensor>();
new_var_desc->SetType(framework::proto::VarType::LOD_TENSOR);
new_var_desc->SetShape(framework::vectorize<int64_t>(tensor.dims()));
new_var_desc->SetLoDLevel(tensor.lod().size());
if (tensor.IsInitialized()) {
new_var_desc->SetDataType(tensor.type());
} else {
new_var_desc->SetDataType(framework::proto::VarType::FP32);
}
} else {
PADDLE_THROW("Not support variable type %s",
framework::ToTypeName(inner_var.Type()));
}
}
void ProgramDescTracer::Reset() {
ops_.clear();
vars_.clear();
feed_vars_.clear();
fetch_vars_.clear();
name_prefix_.clear();
}
} // namespace jit
} // namespace imperative
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <forward_list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/imperative/jit/op_desc_meta.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace imperative {
namespace jit {
class ProgramDescTracer {
DISABLE_COPY_AND_ASSIGN(ProgramDescTracer);
public:
ProgramDescTracer() = default;
void SetNamePrefix(const std::string &name_prefix);
void SetFeedVars(const std::vector<std::shared_ptr<VarBase>> &feed_vars,
std::vector<std::string> feed_names);
void SetFetchVars(const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
std::vector<std::string> fetch_names);
void InsertOp(const std::string &type, const NameVarBaseMap &inputs,
const NameVarBaseMap &outputs,
const framework::AttributeMap &attrs);
std::unique_ptr<framework::ProgramDesc> CreateProgramDesc() const;
void Reset();
private:
void InsertVarIfNotExist(const std::shared_ptr<VarBase> &new_var);
std::vector<std::unique_ptr<OpDescMeta>> ops_;
std::map<std::weak_ptr<VarBase>,
std::pair<size_t, std::unique_ptr<framework::VarDesc>>,
std::owner_less<std::weak_ptr<VarBase>>>
vars_;
// The following fields are used to polish the converted ProgramDesc
std::map<std::weak_ptr<VarBase>, std::string,
std::owner_less<std::weak_ptr<VarBase>>>
feed_vars_;
std::map<std::weak_ptr<VarBase>, std::string,
std::owner_less<std::weak_ptr<VarBase>>>
fetch_vars_;
std::string name_prefix_;
};
} // namespace jit
} // namespace imperative
} // namespace paddle
......@@ -51,6 +51,11 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
auto op = OpBase::Create(op_id, type, ins, outs, std::move(attrs), place);
op->Run(ins, outs);
if (enable_program_desc_tracing_) {
VLOG(5) << "Trace op " << type << " into ProgramDesc";
program_desc_tracer_->InsertOp(type, ins, outs, op->Attrs());
}
if (ComputeRequiredGrad(ins, outs, trace_backward)) {
TraceBackward(op, framework::OpDesc(op->Type(), op->InputNameMap(),
op->OutputNameMap(), op->Attrs()),
......
......@@ -22,6 +22,7 @@
#include <vector>
#include "ThreadPool.h"
#include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/imperative/jit/program_desc_tracer.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/platform/macros.h"
......@@ -32,7 +33,9 @@ class Tracer {
DISABLE_COPY_AND_ASSIGN(Tracer);
public:
Tracer() : engine_(new BasicEngine()) {}
Tracer()
: engine_(new BasicEngine()),
program_desc_tracer_(new jit::ProgramDescTracer()) {}
~Tracer() = default;
......@@ -46,8 +49,21 @@ class Tracer {
void TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
const framework::OpDesc& fwd_op_desc,
const NameVarBaseMap& ins, const NameVarBaseMap& outs);
Engine* GetDefaultEngine() const { return engine_.get(); }
void SetEnableProgramDescTracing(bool enabled) {
enable_program_desc_tracing_ = enabled;
}
bool IsProgramDescTracingEnabled() const {
return enable_program_desc_tracing_;
}
jit::ProgramDescTracer* GetProgramDescTracer() {
return program_desc_tracer_.get();
}
private:
static size_t GenerateUniqueId() {
static std::atomic<size_t> id{0};
......@@ -56,6 +72,8 @@ class Tracer {
private:
std::unique_ptr<Engine> engine_;
std::unique_ptr<jit::ProgramDescTracer> program_desc_tracer_;
bool enable_program_desc_tracing_{false};
};
} // namespace imperative
......
......@@ -29,5 +29,8 @@ class Tracer;
using NameVarBaseMap =
std::map<std::string, std::vector<std::shared_ptr<VarBase>>>;
using WeakNameVarBaseMap =
std::map<std::string, std::vector<std::weak_ptr<VarBase>>>;
} // namespace imperative
} // namespace paddle
......@@ -320,9 +320,24 @@ void BindImperative(py::module *m_ptr) {
return self.Forward(inputs);
});
py::class_<imperative::jit::ProgramDescTracer>(m, "ProgramDescTracer", "")
.def("set_name_prefix",
&imperative::jit::ProgramDescTracer::SetNamePrefix)
.def("set_feed_vars", &imperative::jit::ProgramDescTracer::SetFeedVars)
.def("set_fetch_vars", &imperative::jit::ProgramDescTracer::SetFetchVars)
.def("create_program_desc",
&imperative::jit::ProgramDescTracer::CreateProgramDesc)
.def("reset", &imperative::jit::ProgramDescTracer::Reset);
py::class_<imperative::Tracer>(m, "Tracer", "")
.def("__init__",
[](imperative::Tracer &self) { new (&self) imperative::Tracer(); })
.def_property("_enable_program_desc_tracing",
&imperative::Tracer::IsProgramDescTracingEnabled,
&imperative::Tracer::SetEnableProgramDescTracing)
.def("_get_program_desc_tracer",
&imperative::Tracer::GetProgramDescTracer,
py::return_value_policy::reference)
.def("trace",
[](imperative::Tracer &self, const std::string &type,
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
......
......@@ -27,6 +27,17 @@ __all__ = [
]
@signature_safe_contextmanager
def program_desc_tracing_guard(enable):
tracer = framework._dygraph_tracer()
if tracer:
original_val = tracer._enable_program_desc_tracing
tracer._enable_program_desc_tracing = enable
yield
if tracer:
tracer._enable_program_desc_tracing = original_val
# This function should be removed in V1.6, because it can easily lead to cyclic dependencies.
def enabled():
# Internal use only
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['trace']
from . import layers
from .base import program_desc_tracing_guard
from .layers import Layer
from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard
def create_program_from_desc(program_desc):
program = Program()
program.desc = program_desc
program.blocks = [Block(program, 0)]
program._sync_with_cpp()
return program
def _extract_vars(inputs, result_list):
if isinstance(inputs, Variable):
result_list.append(inputs._ivar)
if isinstance(inputs, (list, tuple)):
for var in inputs:
_extract_vars(var, result_list)
def extract_vars(inputs):
result_list = []
_extract_vars(inputs, result_list)
return result_list
@dygraph_only
def trace(module, inputs, feed_names=None, fetch_names=None):
assert isinstance(module, Layer)
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
if feed_names is None:
feed_names = []
if fetch_names is None:
fetch_names = []
tracer = _dygraph_tracer()._get_program_desc_tracer()
var_list = extract_vars(inputs)
tracer.set_feed_vars(var_list, feed_names)
with program_desc_tracing_guard(True):
original_outputs = module.__call__(*inputs)
if not isinstance(original_outputs, (list, tuple)):
outputs = [original_outputs]
else:
outputs = original_outputs
out_vars = [var._ivar for var in outputs]
tracer.set_fetch_vars(out_vars, fetch_names)
tracer.set_name_prefix('t_')
program_desc = tracer.create_program_desc()
tracer.reset()
with _dygraph_guard(None):
program = create_program_from_desc(program_desc)
return original_outputs, program
......@@ -22,6 +22,7 @@ from . import parallel_helper
from .. import unique_name
from paddle.fluid import core
from .layer_object_helper import LayerObjectHelper
from .base import program_desc_tracing_guard
from paddle.fluid import framework
from ..param_attr import ParamAttr
from paddle.fluid.framework import Variable
......@@ -171,9 +172,11 @@ class Layer(core.Layer):
def __call__(self, *inputs, **kwargs):
if not self._built:
self._build_once(*inputs, **kwargs)
if parallel_helper._is_data_parallel_mode():
parallel_helper._broadcast_parameters(self._parameters.values())
with program_desc_tracing_guard(False):
self._build_once(*inputs, **kwargs)
if parallel_helper._is_data_parallel_mode():
parallel_helper._broadcast_parameters(
self._parameters.values())
outputs = self.forward(*inputs, **kwargs)
self._built = True
......
......@@ -26,6 +26,7 @@ from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC
from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope
from utils import DyGraphProgramDescTracerTestHelper
class SimpleImgConvPool(fluid.dygraph.Layer):
......@@ -135,6 +136,9 @@ class TestImperativeMnist(unittest.TestCase):
mnist.train()
dy_param_init_value = {}
helper = DyGraphProgramDescTracerTestHelper(mnist, self)
for epoch in range(epoch_num):
for batch_id, data in enumerate(batch_py_reader()):
if batch_id >= batch_num:
......@@ -144,7 +148,14 @@ class TestImperativeMnist(unittest.TestCase):
label = data[1]
label.stop_gradient = True
cost = mnist(img)
if batch_id % 10 == 0:
cost, cost_static = helper.run(inputs=img,
feed_names=['image'],
fetch_names=['cost'])
helper.assertEachVar(cost, cost_static)
else:
cost = mnist(img)
loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss)
......
......@@ -24,6 +24,7 @@ from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope
import numpy as np
import six
from utils import DyGraphProgramDescTracerTestHelper
class SimpleLSTMRNN(fluid.Layer):
......@@ -239,6 +240,8 @@ class TestDygraphPtbRnn(unittest.TestCase):
last_hidden = None
last_cell = None
helper = DyGraphProgramDescTracerTestHelper(ptb_model, self)
for i in range(batch_num):
x_data = np.arange(12).reshape(4, 3).astype('int64')
y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
......@@ -252,8 +255,17 @@ class TestDygraphPtbRnn(unittest.TestCase):
y = to_variable(y_data)
init_hidden = to_variable(init_hidden_data)
init_cell = to_variable(init_cell_data)
dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
init_cell)
if i % 5 == 0:
outs, outs_static = helper.run(
[x, y, init_hidden, init_cell],
feed_names=['x', 'y', 'init_hidden', 'init_cell'],
fetch_names=['dy_loss', 'last_hidden', 'last_cell'])
helper.assertEachVar(outs, outs_static)
else:
outs = ptb_model(x, y, init_hidden, init_cell)
dy_loss, last_hidden, last_cell = outs
if i == 0:
for param in ptb_model.parameters():
dy_param_init[param.name] = param.numpy()
......
......@@ -24,6 +24,7 @@ from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid import Conv2D, Pool2D, BatchNorm, FC
from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope
from utils import DyGraphProgramDescTracerTestHelper
batch_size = 8
train_parameters = {
......@@ -248,6 +249,8 @@ class TestDygraphResnet(unittest.TestCase):
for param in resnet.parameters():
dy_param_init_value[param.name] = param.numpy()
helper = DyGraphProgramDescTracerTestHelper(resnet, self)
for batch_id, data in enumerate(batch_py_reader()):
if batch_id >= batch_num:
break
......@@ -256,7 +259,14 @@ class TestDygraphResnet(unittest.TestCase):
label = data[1]
label.stop_gradient = True
out = resnet(img)
if batch_id % 5 == 0:
out, out_static = helper.run(img,
feed_names=['image'],
fetch_names=['logits'])
helper.assertEachVar(out, out_static)
else:
out = resnet(img)
loss = fluid.layers.cross_entropy(input=out, label=label)
avg_loss = fluid.layers.mean(x=loss)
......
......@@ -24,6 +24,8 @@ import numpy as np
import six
np.set_printoptions(suppress=True)
from utils import DyGraphProgramDescTracerTestHelper
# Copy from models
class TrainTaskConfig(object):
......@@ -976,10 +978,28 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
optimizer = fluid.optimizer.SGD(learning_rate=0.003)
dy_param_init = dict()
dy_param_updated = dict()
helper = DyGraphProgramDescTracerTestHelper(transformer, self)
for i in range(batch_num):
enc_inputs, dec_inputs, label, weights = create_data()
dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = transformer(
enc_inputs, dec_inputs, label, weights)
if i % 5 == 0:
outs, outs_static = helper.run(
inputs=[enc_inputs, dec_inputs, label, weights],
feed_names=[
'enc_input_0', 'enc_input_1', 'enc_input_2',
'dec_input_0', 'dec_input_1', 'dec_input_2',
'dec_input_3', 'label', 'weights'
],
fetch_names=[
'dy_sum_cost', 'dy_avg_cost', 'dy_predict',
'dy_token_num'
])
helper.assertEachVar(outs, outs_static)
else:
outs = transformer(enc_inputs, dec_inputs, label, weights)
dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = outs
if i == 0:
for param in transformer.parameters():
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.framework import _dygraph_guard
import paddle.fluid as fluid
from paddle.fluid.framework import Variable
import paddle.fluid.dygraph.jit as jit
from paddle.fluid.dygraph.jit import extract_vars
import numpy as np
import os
import time
__all__ = ['DyGraphProgramDescTracerTestHelper', ]
def is_equal_program(prog1, prog2):
with _dygraph_guard(None):
return _is_equal_program(prog1, prog2)
def _is_equal_program(prog1, prog2):
block_num = prog1.num_blocks
if block_num != prog2.num_blocks:
return False
for block_id in range(block_num):
block1 = prog1.block(block_id)
block2 = prog2.block(block_id)
if len(block1.ops) != len(block2.ops):
return False
if len(block1.vars) != len(block2.vars):
return False
for op1, op2 in zip(block1.ops, block2.ops):
if op1.input_arg_names != op2.input_arg_names:
return False
if op1.output_arg_names != op2.output_arg_names:
return False
attr1 = op1.all_attrs()
attr2 = op2.all_attrs()
if len(attr1) != len(attr2):
return False
for key1, value1 in attr1.items():
if key1 not in attr2:
return False
if value1 != attr2.get(key1):
return False
for var1 in block1.vars.values():
if var1.name not in block2.vars:
return False
var2 = block2.vars.get(var1.name)
if var1.name != var2.name:
return False
if var1.type != var2.type:
return False
if var1.dtype != var2.dtype:
return False
if var1.lod_level != var2.lod_level:
return False
if var1.persistable != var2.persistable:
return False
return True
def load_dygraph_vars_to_scope(model_path, scope, place):
def load_dict_to_scope(scope, dictionary):
if scope is None:
scope = fluid.global_scope()
for k, v in dictionary.items():
dst_t = scope.var(k).get_tensor()
src_t = v.value().get_tensor()
dst_t.set(np.array(src_t), place)
dst_t.set_lod(src_t.lod())
param_dict, opti_dict = fluid.load_dygraph(model_path)
if param_dict:
load_dict_to_scope(scope, param_dict)
if opti_dict:
load_dict_to_scope(scope, opti_dict)
class DyGraphProgramDescTracerTestHelper(object):
def __init__(self,
module,
unittest_obj,
model_path=None,
scope=None,
place=None):
self.module = module
self.unittest_obj = unittest_obj
self.scope = fluid.Scope() if scope is None else scope
self.model_path = model_path
if model_path is None:
millis = int(round(time.time() * 1000))
self.model_path = "id_{}_{}".format(id(module), millis)
self.place = place
if place is None:
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.program = None
self.executor = fluid.Executor(self.place)
def _remove_model_path(self):
if os.path.exists(self.model_path + ".pdparams"):
os.remove(self.model_path + ".pdparams")
if os.path.exists(self.model_path + ".pdopt"):
os.remove(self.model_path + ".pdopt")
def _run_static_graph(self, inputs, feed_names, fetch_names):
var_list = extract_vars(inputs)
assert len(var_list) == len(feed_names)
feed_dict = {}
for name, var in zip(feed_names, var_list):
feed_dict[name] = np.array(var.value().get_tensor())
with fluid.scope_guard(self.scope):
with _dygraph_guard(None):
return self.executor.run(self.program,
feed=feed_dict,
fetch_list=fetch_names)
def run(self, inputs, feed_names, fetch_names):
out_dygraph, program = jit.trace(
self.module, inputs, feed_names=feed_names, fetch_names=fetch_names)
if self.program is not None:
self.unittest_obj.assertTrue(
is_equal_program(self.program, program))
self.program = program
fluid.save_dygraph(self.module.state_dict(), self.model_path)
load_dygraph_vars_to_scope(self.model_path, self.scope, self.place)
self._remove_model_path()
out_static_graph = self._run_static_graph(inputs, feed_names,
fetch_names)
if not isinstance(out_dygraph, (list, tuple)):
assert len(out_static_graph) == 1
out_static_graph = out_static_graph[0]
return out_dygraph, out_static_graph
def assertEachVar(self, out_dygraph, out_static_graph, func=None):
if func is None:
func = lambda x, y: np.array_equal(x, y)
if not isinstance(out_dygraph, (list, tuple)):
out_dygraph = [out_dygraph]
if not isinstance(out_static_graph, (list, tuple)):
out_static_graph = [out_static_graph]
for v1, v2 in zip(out_dygraph, out_static_graph):
self.unittest_obj.assertTrue(func(v1.numpy(), v2))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册