Unverified commit 21440b4d, authored by chengduo, committed by GitHub

Add call stack info during compile time (#19067)

* Add call stack info during runtime and compile time
test=develop

* Rename operator_call_stack
test=develop

* Add unit test
test=develop

* follow comment
test=develop
Parent commit: a99bc64c
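For orientation before the diffs: once this change annotates an EnforceNotMet, its message follows the layout below. This is a sketch assembled from the format strings in the new op_call_stack helper; the placeholder lines in angle brackets stand for whatever was recorded on the op's creation call-stack attribute and the original error text.

Invoke operator <op_type> error.
Python Call stacks:
<frames captured when the operator was appended to the program>
C++ Call stacks:
<the original EnforceNotMet message>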
@@ -124,7 +124,7 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co
 cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
 cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
 cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
-    shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type)
+    shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
@@ -135,6 +135,8 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc
 cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
+cc_library(op_call_stack SRCS op_call_stack.cc DEPS op_proto_maker enforce)
 nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
 py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
...
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_call_stack.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
void InsertCallStackInfo(const std::string &type, const AttributeMap &attrs,
platform::EnforceNotMet *exception) {
if (attrs.count("sub_block") != 0) {
return;
}
auto &callstack = boost::get<std::vector<std::string>>(
attrs.at(OpProtoAndCheckerMaker::OpCreationCallstackAttrName()));
if (callstack.empty()) {
return;
}
std::ostringstream sout;
sout << "Invoke operator " << type << " error.\n";
sout << "Python Call stacks: \n";
for (auto &line : callstack) {
sout << line;
}
sout << "C++ Call stacks: \n";
sout << exception->err_str_;
exception->err_str_ = sout.str();
}
} // namespace framework
} // namespace paddle
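For reference, the call stack consumed above is stored as an ordinary std::vector<std::string> attribute keyed by OpProtoAndCheckerMaker::OpCreationCallstackAttrName(). A minimal, hypothetical illustration of such an attribute map follows; the frame string and the "one_hot" type are made-up examples, not code from this patch.

// Hypothetical example: a creation call stack sitting in an AttributeMap.
framework::AttributeMap attrs;
attrs[framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName()] =
    std::vector<std::string>{"  File \"train.py\", line 42, in build_model\n"};
// InsertCallStackInfo("one_hot", attrs, &exception) would then prepend that frame
// to exception.err_str_.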
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <string>

#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {

void InsertCallStackInfo(const std::string &type, const AttributeMap &attrs,
                         platform::EnforceNotMet *exception);

}  // namespace framework
}  // namespace paddle
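Both call sites updated by this patch (OpDesc::InferShape in the next diff and OperatorBase::Run after it) wrap their bodies in the same pattern. A minimal sketch of that pattern, with GuardedBody() standing in as a placeholder for the protected code:

try {
  GuardedBody();  // placeholder for the InferShape / RunImpl logic being guarded
} catch (platform::EnforceNotMet exception) {
  // Caught by value so the copy's err_str_ can be rewritten before rethrowing.
  framework::InsertCallStackInfo(Type(), Attrs(), &exception);
  throw std::move(exception);
} catch (...) {
  // Other exception types pass through unchanged.
  std::rethrow_exception(std::current_exception());
}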
@@ -18,8 +18,10 @@ limitations under the License. */
 #include <mutex>  // NOLINT
 #include <string>
 #include <unordered_map>
+#include <utility>
 #include "glog/logging.h"
 #include "paddle/fluid/framework/block_desc.h"
+#include "paddle/fluid/framework/op_call_stack.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
@@ -679,6 +681,7 @@ void OpDesc::CheckAttrs() {
 }
 void OpDesc::InferShape(const BlockDesc &block) const {
+  try {
   VLOG(3) << "CompileTime infer shape on " << Type();
   InitInferShapeFuncs();
   auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
@@ -699,6 +702,12 @@ void OpDesc::InferShape(const BlockDesc &block) const {
     VLOG(10) << sout.str();
   }
   infer_shape(&ctx);
+  } catch (platform::EnforceNotMet exception) {
+    framework::InsertCallStackInfo(Type(), attrs_, &exception);
+    throw std::move(exception);
+  } catch (...) {
+    std::rethrow_exception(std::current_exception());
+  }
 }
 void OpDesc::InferVarType(BlockDesc *block) const {
...
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_transform.h"
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_call_stack.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/shape_inference.h"
@@ -186,28 +187,9 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
     } else {
       RunImpl(scope, place);
     }
     VLOG(3) << place << " " << DebugStringEx(&scope);
   } catch (platform::EnforceNotMet exception) {
-    if (Attrs().count("sub_block") != 0) {
-      throw std::move(exception);
-    }
-    auto& callstack = Attr<std::vector<std::string>>(
-        OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
-    if (callstack.empty()) {
-      throw std::move(exception);
-    }
-    std::ostringstream sout;
-    sout << "Invoke operator " << Type() << " error.\n";
-    sout << "Python Callstacks: \n";
-    for (auto& line : callstack) {
-      sout << line;
-    }
-    sout << "C++ Callstacks: \n";
-    sout << exception.err_str_;
-    exception.err_str_ = sout.str();
+    framework::InsertCallStackInfo(Type(), Attrs(), &exception);
     throw std::move(exception);
   } catch (...) {
     std::rethrow_exception(std::current_exception());
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
import paddle.fluid.core as core
class TestRunTimeException(OpTest):
    def test_run_time_exception(self):
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        train_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(train_program, startup_program):
            label = fluid.layers.data(name="label", shape=[1], dtype="int64")
            fluid.layers.one_hot(input=label, depth=100)

        def _run_program():
            x = np.random.random(size=(10)).astype('int64')
            exe.run(train_program, feed={"label": x})

        self.assertRaises(core.EnforceNotMet, _run_program)


class TestCompileTimeException(OpTest):
    def test_compile_time_exception(self):
        self.assertRaises(core.EnforceNotMet, self.build_model)

    def build_model(self):
        train_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(train_program, startup_program):
            label = fluid.layers.data(
                name="label", shape=[1], dtype="int64", append_batch_size=False)
            fluid.layers.one_hot(input=label, depth=100)


if __name__ == '__main__':
    unittest.main()