提交 bbceb723 编写于 作者: Q qijun

refine some codes

上级 48b080db
......@@ -74,16 +74,6 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) {
for (auto& device_context : device_contexts_) {
device_context->Wait();
}
// // print tensor value
// for (auto& var : block.vars()) {
// std::cout << var.name() << std::endl;
// auto v = scope->FindVar(var.name());
// const LoDTensor& t = v->Get<LoDTensor>();
// for (int i = 0; i < t.numel(); ++i) {
// std::cout << t.data<float>()[i] << " ";
// }
// std::cout << std::endl;
// }
}
} // namespace framework
......
......@@ -130,6 +130,7 @@ std::once_flag set_variable_flag;
template <typename T>
void set_feed_variable(const std::vector<std::vector<T>>& inputs) {
typedef std::vector<paddle::framework::Tensor> FeedInputs;
// Tensors in feed value variable will only be in CPUPlace
Variable* g_feed_value = GetScope()->FindVar("feed_value");
FeedInputs& feed_inputs = *(g_feed_value->GetMutable<FeedInputs>());
auto size = inputs.size();
......@@ -144,6 +145,7 @@ void set_feed_variable(const std::vector<std::vector<T>>& inputs) {
template <typename T>
std::vector<std::vector<T>> get_fetch_variable() {
typedef std::vector<paddle::framework::Tensor> FetchOutputs;
// Tensors in fetch value variable will only be in CPUPlace
Variable* g_fetch_value = GetScope()->FindVar("fetch_value");
FetchOutputs& fetch_outputs = *(g_fetch_value->GetMutable<FetchOutputs>());
......
......@@ -66,15 +66,10 @@ void Scope::DropKids() {
std::once_flag feed_variable_flag;
// Backport of C++14 std::make_unique: constructs a T on the heap,
// perfect-forwarding the constructor arguments, and wraps it in a
// std::unique_ptr so ownership is established in a single expression.
template <typename T, typename... Args>
std::unique_ptr<T> make_unique(Args&&... args) {
  T* raw_object = new T(std::forward<Args>(args)...);
  return std::unique_ptr<T>(raw_object);
}
framework::Scope* GetScope() {
static std::unique_ptr<framework::Scope> g_scope =
make_unique<framework::Scope>();
static std::unique_ptr<framework::Scope> g_scope{nullptr};
std::call_once(feed_variable_flag, [&]() {
g_scope.reset(new framework::Scope());
g_scope->NewVar("feed_value");
g_scope->NewVar("fetch_value");
});
......
......@@ -33,7 +33,7 @@ class FeedOp : public framework::OperatorWithKernel {
auto in_dim = tensors[col].dims();
ctx->SetOutputDim("Out", in_dim);
// TODO(qijun) need to handle LodTensor later
// TODO(qijun): need to handle LodTensor later
}
framework::DataType IndicateDataType(
......
......@@ -39,7 +39,7 @@ class FetchOp : public framework::OperatorWithKernel {
tmp.Resize(input_dim);
(*tensors)[col].Resize(input_dim);
// TODO(qijun) need to handle LodTensor later
// TODO(qijun): need to handle LodTensor later
}
framework::DataType IndicateDataType(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册