Commit 1cd61122 authored by zlsh80826

Merge branch 'fix_stack_op_conflict' into trt_stack_opi, test=develop

@@ -58,7 +58,6 @@ namespace framework {
 std::once_flag gflags_init_flag;
 std::once_flag glog_init_flag;
 std::once_flag p2p_init_flag;
-std::once_flag glog_warning_once_flag;
 
 bool InitGflags(std::vector<std::string> args) {
   bool successed = false;
@@ -260,22 +259,22 @@ const char *ParseSignalErrorString(const std::string &str) {
 }
 
 // Handle SIGSEGV, SIGILL, SIGFPE, SIGABRT, SIGBUS, and SIGTERM.
-std::ostringstream signal_msg_dumper;
 void SignalHandle(const char *data, int size) {
   try {
     // NOTE1: The glog FailureSignalHandler dumped messages
     //        are deal with line by line
+    auto signal_msg_dunmer_ptr = SignalMessageDumper::Instance().Get();
     // NOTE2: we only deal with the time info ane signal info,
     //        the stack trace will generated by paddle self
     if (StartsWith(data, "*** Aborted at")) {
-      signal_msg_dumper << " [TimeInfo: " << std::string(data, size - 1)
-                        << "]\n";
+      *signal_msg_dunmer_ptr << " [TimeInfo: " << std::string(data, size - 1)
+                             << "]\n";
     } else if (StartsWith(data, "***")) {
       std::string signal_info(data, size - 1);
       std::string useless_substr("; stack trace:");
       size_t start_pos = signal_info.rfind(useless_substr);
       signal_info.replace(start_pos, useless_substr.length(), "");
-      signal_msg_dumper << " [SignalInfo: " << signal_info << "]\n";
+      *signal_msg_dunmer_ptr << " [SignalInfo: " << signal_info << "]\n";
       // NOTE3: Here does not throw an exception,
       //        otherwise it will casue "terminate called recursively"
       auto exp = platform::EnforceNotMet(
@@ -283,7 +282,7 @@ void SignalHandle(const char *data, int size) {
               "A serious error (%s) is detected by the operating system.",
               ParseSignalErrorString(signal_info)),
           __FILE__, __LINE__);
-      std::cout << exp.what() << signal_msg_dumper.str() << std::endl;
+      std::cout << exp.what() << (*signal_msg_dunmer_ptr).str() << std::endl;
     }
   } catch (...) {
     // Since the program has already triggered a system error,
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+#include <memory>
 #include <mutex>  // NOLINT
 #include <string>
 #include <vector>
@@ -22,7 +23,7 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
-void ParseCommandLineFlags(int argc, char **argv, bool remove);
+void ParseCommandLineFlags(int argc, char** argv, bool remove);
 }  // namespace platform
 }  // namespace paddle
@@ -32,14 +33,32 @@ namespace framework {
 bool InitGflags(std::vector<std::string> argv);
-void InitGLOG(const std::string &prog_name);
+void InitGLOG(const std::string& prog_name);
 void InitDevices(bool init_p2p);
 void InitDevices(bool init_p2p, const std::vector<int> devices);
 
 #ifndef _WIN32
-void SignalHandle(const char *data, int size);
+class SignalMessageDumper {
+ public:
+  ~SignalMessageDumper() {}
+  SignalMessageDumper(const SignalMessageDumper& o) = delete;
+  const SignalMessageDumper& operator=(const SignalMessageDumper& o) = delete;
+
+  static SignalMessageDumper& Instance() {
+    static SignalMessageDumper instance;
+    return instance;
+  }
+
+  std::shared_ptr<std::ostringstream> Get() { return dumper_; }
+
+ private:
+  SignalMessageDumper() : dumper_(new std::ostringstream()) {}
+  std::shared_ptr<std::ostringstream> dumper_;
+};
+
+void SignalHandle(const char* data, int size);
 #endif
 
 }  // namespace framework
...
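Note (not part of the diff): the header change above swaps a file-scope std::ostringstream for a lazily constructed singleton whose Get() hands out a shared buffer, so the glog failure handler's time and signal lines can be collected and later printed together with Paddle's own error message. A rough Python analogue of that pattern, with purely illustrative names, could look like this:

import io

class MessageDumper(object):
    # Lazily created holder for a shared message buffer.
    _instance = None

    @classmethod
    def instance(cls):
        # Build the singleton on first use, mirroring the function-local
        # static object in the C++ Instance() method above.
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        self._buffer = io.StringIO()

    def get(self):
        return self._buffer

def handle_signal_line(line):
    # Append one dumped line to the shared buffer instead of printing it.
    MessageDumper.instance().get().write(line + "\n")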
@@ -112,14 +112,14 @@ class PartialProgramLayer(layers.Layer):
         self._outputs = NestSequence(outputs, need_check=True)
         self._params = parameters if parameters is not None else []
 
-        self._infer_program = self._verify_program(main_program)
-        self._train_program = self._append_backward_desc()
-        # Switch infer or train by train() and eval()
-        self._trace_program = None
+        main_program = self._verify_program(main_program)
+        self._infer_program = self._clone_for_test(main_program)
+        self._train_program = self._append_backward_desc(main_program)
 
         self._set_grad_type(self._params)
         self._inner_scope = core.Scope()
         # Set default mode to train
-        self.train()
+        self.training = True
 
     def _verify_program(self, main_program):
         """
@@ -136,8 +136,8 @@ class PartialProgramLayer(layers.Layer):
         return main_program
 
     @switch_to_static_graph
-    def _append_backward_desc(self):
-        program = self._infer_program.clone()
+    def _append_backward_desc(self, main_program):
+        program = main_program.clone()
         targets = []
         for out in self._outputs.tolist():
             if isinstance(out, framework.Variable):
@@ -165,15 +165,6 @@ class PartialProgramLayer(layers.Layer):
         self._params = required_params
 
-    def train(self):
-        # self.training is inherited from layers.Layer
-        self.training = True
-        self._trace_program = self._train_program
-
-    def eval(self):
-        self.training = False
-        self._trace_program = self._infer_program
-
     def forward(self, inputs):
         in_vars, out_vars, tmp_scope_vec = self._prepare(inputs)
@@ -186,7 +177,7 @@ class PartialProgramLayer(layers.Layer):
             outputs={'Out': valid_vars(out_vars),
                      'OutScope': tmp_scope_vec},
             attrs={
-                'global_block': self._trace_program.desc.block(0),
+                'global_block': self.program.desc.block(0),
                 'start_op_index': 0,
                 'end_op_index': self._infer_program.desc.block(0).op_size(),
                 'is_test': not self.training
@@ -195,6 +186,10 @@ class PartialProgramLayer(layers.Layer):
         restored_nest_out = self._restore_out(out_vars)
         return self._remove_no_value(restored_nest_out)
 
+    @property
+    def program(self):
+        return self._train_program if self.training else self._infer_program
+
     def _prepare(self, inputs):
         """
         Prepare inputs, outputs, attrs.
@@ -253,6 +248,10 @@ class PartialProgramLayer(layers.Layer):
         return outs
 
+    @switch_to_static_graph
+    def _clone_for_test(self, main_program):
+        return main_program.clone(for_test=True)
+
     def _is_no_value(self, var):
         if isinstance(var, core.VarBase):
             if var.shape == [1] and var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
...
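The PartialProgramLayer changes above drop the explicit train()/eval() overrides and the _trace_program field in favor of a training flag plus a program property. A minimal sketch of that pattern (not the Paddle implementation; names are illustrative only):

class ModeSwitchingLayer(object):
    def __init__(self, infer_program, train_program):
        self._infer_program = infer_program
        self._train_program = train_program
        # Default to training mode, matching `self.training = True` above.
        self.training = True

    @property
    def program(self):
        # forward() reads self.program, so flipping `training` is enough to
        # switch which program description gets executed.
        return self._train_program if self.training else self._infer_program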
@@ -487,6 +487,8 @@ class ProgramTranslator(object):
         _, partial_program_layer = self._program_cache[function_spec]
 
         if args and isinstance(args[0], layers.Layer):
+            # Synchronize self.training attribute.
+            partial_program_layer.training = args[0].training
             args = args[1:]
         return partial_program_layer(args)
...
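The two added lines above keep the cached static-graph layer in the same mode as the dygraph Layer that owns the decorated method. A hedged sketch of that call path, using hypothetical names rather than the ProgramTranslator API:

def call_cached_layer(layer_instance, partial_program_layer, *inputs):
    # Copy the caller's train/eval state onto the shared, cached layer
    # right before execution, so `is_test` is derived from the right mode.
    partial_program_layer.training = layer_instance.training
    return partial_program_layer(inputs)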
@@ -16,7 +16,9 @@ from __future__ import print_function
 import numpy as np
 
 import paddle.fluid as fluid
 from paddle.fluid.layers.utils import flatten
-from paddle.fluid.dygraph import declarative
+from paddle.fluid.dygraph import declarative, ProgramTranslator
+
+from test_fetch_feed import Linear
 
 import unittest
@@ -121,5 +123,33 @@ class TestWithNestedOutput(unittest.TestCase):
             self.assertTrue(dy_var, st_var)
 
+class TestWithTrainAndEval(unittest.TestCase):
+    def test_switch_eval_and_train(self):
+        program_translator = ProgramTranslator()
+        with fluid.dygraph.guard():
+            linear_net = Linear()
+            x_data = np.random.random((4, 10)).astype('float32')
+            x = fluid.dygraph.to_variable(x_data)
+            linear_net(x)
+
+            _, partial_layer = program_translator.get_program_cache().last()[-1]
+            # check default mode is for training
+            self.assertEqual(partial_layer.program,
+                             partial_layer._train_program)
+
+            # switch to run test program after `eval()`
+            linear_net.eval()
+            linear_net(x)
+            self.assertEqual(partial_layer.program,
+                             partial_layer._infer_program)
+
+            # switch back into training
+            linear_net.train()
+            linear_net(x)
+            self.assertEqual(partial_layer.program,
+                             partial_layer._train_program)
+
+
 if __name__ == '__main__':
     unittest.main()