未验证 提交 1226cc01 编写于 作者: W wanghuancoder 提交者: GitHub

runtimecontext (#33608)

* runtimecontext

* ExecutionContextV2

* refine

* refine

* pass test
上级 bc8a8042
...@@ -409,7 +409,7 @@ cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framewo ...@@ -409,7 +409,7 @@ cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framewo
cc_test(custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog) cc_test(custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog)
#cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} ) #cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} )
cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler) cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler place)
set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator) set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <gperftools/profiler.h>
#include <chrono>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <map> #include <map>
#include <memory> #include <memory>
#include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/init.h"
#include <chrono> namespace paddle {
#include <gperftools/profiler.h> namespace framework {
class RuntimeContextV2 {
public:
RuntimeContextV2(std::vector<std::vector<Variable*>>& in_values, // NOLINT
std::vector<std::vector<Variable*>>& out_values, // NOLINT
const std::map<std::string, size_t>& in_name_map,
const std::map<std::string, size_t>& out_name_map)
: input_values(std::move(in_values)),
output_values(std::move(out_values)),
input_name_map(in_name_map),
output_name_map(out_name_map) {}
std::vector<std::vector<Variable*>> input_values;
std::vector<std::vector<Variable*>> output_values;
const std::map<std::string, size_t>& input_name_map;
const std::map<std::string, size_t>& output_name_map;
};
//USE_OP(fill_constant); class ExecutionContextV2 : public ExecutionContext {
//USE_OP(elementwise_add); public:
ExecutionContextV2(const OperatorBase& op, const Scope& scope,
const platform::DeviceContext& device_context,
const RuntimeContextV2& ctx)
: ExecutionContext(op, scope, device_context, RuntimeContext({}, {})),
ctx_(ctx) {}
using namespace std; const std::vector<Variable*> MultiInputVar(const std::string& name) const {
LogVarUsageIfUnusedVarCheckEnabled(name);
namespace paddle { auto it = ctx_.input_name_map.find(name);
namespace framework { if (it == ctx_.input_name_map.end()) {
return {};
}
// return {it->second.begin(), it->second.end()};
return ctx_.input_values[it->second];
}
std::vector<Variable*> MultiOutputVar(const std::string& name) const {
auto it = ctx_.output_name_map.find(name);
if (it == ctx_.output_name_map.end()) {
return {};
}
// return it->second;
return ctx_.output_values[it->second];
}
std::vector<std::string> InNameList() const {
std::vector<std::string> vec_temp;
vec_temp.reserve(ctx_.output_name_map.size());
for (auto& input : ctx_.output_name_map) {
vec_temp.push_back(input.first);
}
return vec_temp;
}
const Variable* InputVar(const std::string& name) const {
LogVarUsageIfUnusedVarCheckEnabled(name);
auto it = ctx_.input_name_map.find(name);
if (it == ctx_.input_name_map.end()) return nullptr;
PADDLE_ENFORCE_LE(
ctx_.input_values[it->second].size(), 1UL,
platform::errors::InvalidArgument(
"Operator %s's input %s should contain only one variable.",
GetOp().Type(), name));
return ctx_.input_values[it->second].empty()
? nullptr
: ctx_.input_values[it->second][0];
}
Variable* OutputVar(const std::string& name) const {
auto it = ctx_.output_name_map.find(name);
if (it == ctx_.output_name_map.end()) return nullptr;
PADDLE_ENFORCE_LE(
ctx_.output_values[it->second].size(), 1UL,
platform::errors::InvalidArgument(
"Operator %s's output %s should contain only one variable.",
GetOp().Type(), name));
return ctx_.output_values[it->second].empty()
? nullptr
: ctx_.output_values[it->second][0];
}
const RuntimeContextV2& ctx_;
};
class RuntimeInferShapeContext : public InferShapeContext { class RuntimeInferShapeContext : public InferShapeContext {
public: public:
RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx) RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContextV2& ctx)
: op_(op), ctx_(ctx) {} : op_(op), ctx_(ctx) {}
bool HasInput(const std::string& name) const override { bool HasInput(const std::string& name) const override {
// has only one input // has only one input
const auto& ins = ctx_.inputs; const auto& ins = ctx_.input_name_map;
auto it = ins.find(name); auto it = ins.find(name);
if (it == ins.end()) { if (it == ins.end()) {
return false; return false;
} }
const auto& in = it->second; const auto& in = ctx_.input_values[it->second];
if (in.size() == 0) return false; if (in.size() == 0) return false;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in.size(), 1UL, in.size(), 1UL,
...@@ -55,12 +151,12 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -55,12 +151,12 @@ class RuntimeInferShapeContext : public InferShapeContext {
bool HasOutput(const std::string& name) const override { bool HasOutput(const std::string& name) const override {
// has only one output // has only one output
const auto& outs = ctx_.outputs; const auto& outs = ctx_.output_name_map;
auto it = outs.find(name); auto it = outs.find(name);
if (it == outs.end()) { if (it == outs.end()) {
return false; return false;
} }
const auto& out = it->second; const auto& out = ctx_.output_values[it->second];
if (out.size() == 0) { if (out.size() == 0) {
return false; return false;
} }
...@@ -72,12 +168,12 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -72,12 +168,12 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
bool HasInputs(const std::string& name) const override { bool HasInputs(const std::string& name) const override {
const auto& ins = ctx_.inputs; const auto& ins = ctx_.input_name_map;
auto it = ins.find(name); auto it = ins.find(name);
if (it == ins.end() || it->second.empty()) { if (it == ins.end() || ctx_.input_values[it->second].empty()) {
return false; return false;
} }
for (auto& input : it->second) { for (auto& input : ctx_.input_values[it->second]) {
if (input == nullptr) { if (input == nullptr) {
return false; return false;
} }
...@@ -86,12 +182,12 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -86,12 +182,12 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
bool HasOutputs(const std::string& name) const override { bool HasOutputs(const std::string& name) const override {
const auto& outs = ctx_.outputs; const auto& outs = ctx_.output_name_map;
auto it = outs.find(name); auto it = outs.find(name);
if (it == outs.end() || it->second.empty()) { if (it == outs.end() || ctx_.output_values[it->second].empty()) {
return false; return false;
} }
for (auto& output : it->second) { for (auto& output : ctx_.output_values[it->second]) {
if (output == nullptr) { if (output == nullptr) {
return false; return false;
} }
...@@ -134,27 +230,27 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -134,27 +230,27 @@ class RuntimeInferShapeContext : public InferShapeContext {
void ShareDim(const std::string& in, const std::string& out, size_t i = 0, void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) override { size_t j = 0) override {
auto in_it = ctx_.inputs.find(in); auto in_it = ctx_.input_name_map.find(in);
auto out_it = ctx_.outputs.find(out); auto out_it = ctx_.output_name_map.find(out);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
in_it, ctx_.inputs.end(), in_it, ctx_.input_name_map.end(),
platform::errors::NotFound("Input %s does not exist.", in)); platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
out_it, ctx_.outputs.end(), out_it, ctx_.output_name_map.end(),
platform::errors::NotFound("Output %s does not exist.", out)); platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i, in_it->second.size(), PADDLE_ENFORCE_LT(i, ctx_.input_values[in_it->second].size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The index of input dimension is out of range, " "The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu.", "excepted index less than %zu, but received %zu.",
in_it->second.size(), i)); ctx_.input_values[in_it->second].size(), i));
PADDLE_ENFORCE_LT(j, out_it->second.size(), PADDLE_ENFORCE_LT(j, ctx_.output_values[out_it->second].size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The index of output dimension is out of range, " "The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.", "excepted index less than %zu, but received %zu.",
out_it->second.size(), j)); ctx_.output_values[out_it->second].size(), j));
Variable* in_var = in_it->second[i]; Variable* in_var = ctx_.input_values[in_it->second][i];
Variable* out_var = out_it->second[j]; Variable* out_var = ctx_.output_values[out_it->second][j];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_var->Type(), out_var->Type(), in_var->Type(), out_var->Type(),
...@@ -181,18 +277,18 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -181,18 +277,18 @@ class RuntimeInferShapeContext : public InferShapeContext {
void ShareAllLoD(const std::string& in, void ShareAllLoD(const std::string& in,
const std::string& out) const override { const std::string& out) const override {
auto in_it = ctx_.inputs.find(in); auto in_it = ctx_.input_name_map.find(in);
auto out_it = ctx_.outputs.find(out); auto out_it = ctx_.output_name_map.find(out);
PADDLE_ENFORCE_NE(in_it, ctx_.inputs.end(), PADDLE_ENFORCE_NE(in_it, ctx_.input_name_map.end(),
platform::errors::NotFound( platform::errors::NotFound(
"Input [%s] found error in Op [%s]", in, op_.Type())); "Input [%s] found error in Op [%s]", in, op_.Type()));
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
out_it, ctx_.outputs.end(), out_it, ctx_.output_name_map.end(),
platform::errors::NotFound("Output [%s] found error in Op [%s]", out, platform::errors::NotFound("Output [%s] found error in Op [%s]", out,
op_.Type())); op_.Type()));
auto& in_var_list = in_it->second; auto& in_var_list = ctx_.input_values[in_it->second];
auto& out_var_list = out_it->second; auto& out_var_list = ctx_.output_values[out_it->second];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_var_list.size(), out_var_list.size(), in_var_list.size(), out_var_list.size(),
...@@ -226,28 +322,28 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -226,28 +322,28 @@ class RuntimeInferShapeContext : public InferShapeContext {
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override { size_t j = 0) const override {
auto in_it = ctx_.inputs.find(in); auto in_it = ctx_.input_name_map.find(in);
auto out_it = ctx_.outputs.find(out); auto out_it = ctx_.output_name_map.find(out);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
in_it, ctx_.inputs.end(), in_it, ctx_.input_name_map.end(),
platform::errors::NotFound("Input %s does not exist.", in)); platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
out_it, ctx_.outputs.end(), out_it, ctx_.output_name_map.end(),
platform::errors::NotFound("Output %s does not exist.", out)); platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i, in_it->second.size(), PADDLE_ENFORCE_LT(i, ctx_.input_values[in_it->second].size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The index of input dimension is out of range, " "The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu.", "excepted index less than %zu, but received %zu.",
in_it->second.size(), i)); ctx_.input_values[in_it->second].size(), i));
PADDLE_ENFORCE_LT(j, out_it->second.size(), PADDLE_ENFORCE_LT(j, ctx_.output_values[out_it->second].size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The index of output dimension is out of range, " "The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.", "excepted index less than %zu, but received %zu.",
out_it->second.size(), j)); ctx_.output_values[out_it->second].size(), j));
Variable* in_var = in_it->second.at(i); Variable* in_var = ctx_.input_values[in_it->second].at(i);
if (!in_var->IsType<LoDTensor>()) return; if (!in_var->IsType<LoDTensor>()) return;
Variable* out_var = out_it->second.at(j); Variable* out_var = ctx_.output_values[out_it->second].at(j);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
out_var->IsType<LoDTensor>(), true, out_var->IsType<LoDTensor>(), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -339,7 +435,7 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -339,7 +435,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
void SetOutputDim(const std::string& name, const DDim& dim) override { void SetOutputDim(const std::string& name, const DDim& dim) override {
//cerr << "set out dim" << endl; // std::cerr << "set out dim" << std::endl;
auto& vars = OutputVars(name); auto& vars = OutputVars(name);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
vars.size(), 1UL, vars.size(), 1UL,
...@@ -385,9 +481,7 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -385,9 +481,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
void SetDim(Variable* var, const DDim& dim) { void SetDim(Variable* var, const DDim& dim) {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim); var->GetMutable<LoDTensor>()->Resize(dim);
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
var->GetMutable<SelectedRows>()->set_height(dim[0]); var->GetMutable<SelectedRows>()->set_height(dim[0]);
...@@ -438,114 +532,106 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -438,114 +532,106 @@ class RuntimeInferShapeContext : public InferShapeContext {
private: private:
const std::vector<Variable*>& InputVars(const std::string& name) const { const std::vector<Variable*>& InputVars(const std::string& name) const {
auto it = ctx_.inputs.find(name); auto it = ctx_.input_name_map.find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, ctx_.inputs.end(), it, ctx_.input_name_map.end(),
platform::errors::NotFound( platform::errors::NotFound(
"Operator (%s) does not have the input (%s).", op_.Type(), name)); "Operator (%s) does not have the input (%s).", op_.Type(), name));
return it->second; return ctx_.input_values[it->second];
} }
const std::vector<Variable*>& OutputVars(const std::string& name) const { const std::vector<Variable*>& OutputVars(const std::string& name) const {
auto it = ctx_.outputs.find(name); auto it = ctx_.output_name_map.find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, ctx_.outputs.end(), it, ctx_.output_name_map.end(),
platform::errors::NotFound( platform::errors::NotFound(
"Operator (%s) does not have the outputs (%s).", op_.Type(), name)); "Operator (%s) does not have the outputs (%s).", op_.Type(), name));
return it->second; return ctx_.output_values[it->second];
} }
const OperatorBase& op_; const OperatorBase& op_;
const RuntimeContext& ctx_; const RuntimeContextV2& ctx_;
}; };
framework::ProgramDesc load_from_file(const std::string& file_name) {
framework::ProgramDesc load_from_file( const std::string& file_name )
{
std::ifstream fin(file_name, std::ios::in | std::ios::binary); std::ifstream fin(file_name, std::ios::in | std::ios::binary);
if (!fin.is_open()) {
std::cout << "open file " << file_name << " faild!" << std::endl;
}
fin.seekg(0, std::ios::end); fin.seekg(0, std::ios::end);
std::string buffer(fin.tellg(), ' '); std::string buffer(fin.tellg(), ' ');
fin.seekg(0, std::ios::beg); fin.seekg(0, std::ios::beg);
fin.read(&buffer[0], buffer.size()); fin.read(&buffer[0], buffer.size());
fin.close(); fin.close();
ProgramDesc program_desc(buffer);
ProgramDesc program_desc( buffer );
return program_desc; return program_desc;
} }
struct VariableScope {
struct VariableScope std::vector<std::unique_ptr<Variable>> var_list;
{ std::map<std::string, size_t> name2id;
std::vector< std::unique_ptr<Variable> > var_list;
std::map<std::string, int> name2id;
}; };
struct OpFuncNode {
// int unsed;
// std::map< std::string, std::vector<int> > input_index;
struct OpFuncNode{ // std::map< std::string, std::vector<int> > output_index;
std::vector<std::vector<size_t>> input_index;
//int unsed; std::vector<std::vector<size_t>> output_index;
std::map< std::string, std::vector<int> > input_index; std::map<std::string, size_t> input_name_map;
std::map< std::string, std::vector<int> > output_index; std::map<std::string, size_t> output_name_map;
using OpKernelFunc = std::function<void(const ExecutionContext&)>; using OpKernelFunc = std::function<void(const ExecutionContext&)>;
OpKernelFunc kernel_func_; OpKernelFunc kernel_func_;
}; };
int convert(const platform::Place& place ) int convert(const platform::Place& place) {
{ if (is_cpu_place(place)) {
if ( is_cpu_place(place ))
{
return 0; return 0;
} }
if( is_gpu_place( place )) if (is_gpu_place(place)) {
{
return 1; return 1;
} }
return -1; return -1;
} }
void build_variable_scope( const framework::ProgramDesc& pdesc, VariableScope* var_scope ) void build_variable_scope(const framework::ProgramDesc& pdesc,
{ VariableScope* var_scope) {
auto& global_block = pdesc.Block(0); auto& global_block = pdesc.Block(0);
for (auto& var : global_block.AllVars()) { for (auto& var : global_block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) { if (var->Name() == framework::kEmptyVarName) {
continue; continue;
} }
//cerr << "var name " << var->Name() << endl; // std::cerr << "var name " << var->Name() << std::endl;
if ( var_scope->name2id.find( var->Name() ) == var_scope->name2id.end() ) if (var_scope->name2id.find(var->Name()) == var_scope->name2id.end()) {
{ var_scope->name2id[var->Name()] = var_scope->var_list.size();
var_scope->name2id[ var->Name() ] = var_scope->var_list.size();
} }
auto v = new Variable(); auto v = new Variable();
//v->GetMutable<LoDTensor>(); // v->GetMutable<LoDTensor>();
InitializeVariable(v, var->GetType()); InitializeVariable(v, var->GetType());
var_scope->var_list.push_back(std::unique_ptr<Variable>(v)); var_scope->var_list.push_back(std::unique_ptr<Variable>(v));
} }
} }
void build_op_func_list( const framework::ProgramDesc& pdesc, std::vector<OperatorBase* >& op_list, void build_op_func_list(const framework::ProgramDesc& pdesc,
std::vector<OpFuncNode>& vec_func_list, VariableScope* var_scope, std::vector<OperatorBase*>& op_list, // NOLINT
const platform::Place& place ) std::vector<OpFuncNode>& vec_func_list, // NOLINT
{ VariableScope* var_scope,
auto &global_block = pdesc.Block( 0 ); const platform::Place& place) {
auto& global_block = pdesc.Block(0);
for ( auto& op : global_block.AllOps() ) for (auto& op : global_block.AllOps()) {
{ // std::cerr << op->Type() << std::endl;
//cerr << op->Type() << endl; // bool debug = op->Type() == "softmax_with_cross_entropy_grad";
//bool debug = op->Type() == "softmax_with_cross_entropy_grad";
bool debug = false; bool debug = false;
//cerr << "create op" << endl; // std::cerr << "create op" << std::endl;
//auto op_base_u = OpRegistry::CreateOp(*op); // auto op_base_u = OpRegistry::CreateOp(*op);
auto& info = OpInfoMap::Instance().Get( op->Type() ); auto& info = OpInfoMap::Instance().Get(op->Type());
VariableNameMap inputs_1 = op->Inputs(); VariableNameMap inputs_1 = op->Inputs();
VariableNameMap outputs_1 = op->Outputs(); VariableNameMap outputs_1 = op->Outputs();
...@@ -554,368 +640,407 @@ void build_op_func_list( const framework::ProgramDesc& pdesc, std::vector<Operat ...@@ -554,368 +640,407 @@ void build_op_func_list( const framework::ProgramDesc& pdesc, std::vector<Operat
if (info.Checker() != nullptr) { if (info.Checker() != nullptr) {
info.Checker()->Check(&attrs_1); info.Checker()->Check(&attrs_1);
} }
auto op_base = info.Creator()( op->Type(), inputs_1, outputs_1, attrs_1); auto op_base = info.Creator()(op->Type(), inputs_1, outputs_1, attrs_1);
auto input_names = op->Inputs(); auto input_names = op->Inputs();
auto output_names = op->Outputs(); auto output_names = op->Outputs();
OpFuncNode op_func_node; OpFuncNode op_func_node;
VariableValueMap ins_map; // VariableValueMap ins_map;
std::map< std::string, std::vector<int> > ins_name2id; // std::map<std::string, std::vector<int> > ins_name2id;
for( auto& var_name_item : input_names) std::vector<std::vector<Variable*>> ins_value;
{ std::vector<std::vector<size_t>> ins_index;
std::map<std::string, size_t> ins_name_map;
for (auto& var_name_item : input_names) {
std::vector<Variable*> input_vars; std::vector<Variable*> input_vars;
std::vector<int> vec_ids; std::vector<size_t> vec_ids;
input_vars.reserve(var_name_item.second.size()); input_vars.reserve(var_name_item.second.size());
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
auto it = var_scope->name2id.find( var_name ); auto it = var_scope->name2id.find(var_name);
assert( it != var_scope->name2id.end() ); assert(it != var_scope->name2id.end());
input_vars.push_back( var_scope->var_list[ it->second].get()); input_vars.push_back(var_scope->var_list[it->second].get());
vec_ids.push_back( it->second ); vec_ids.push_back(it->second);
} }
ins_map[ var_name_item.first ] = input_vars; ins_value.emplace_back(std::move(input_vars));
ins_name2id[ var_name_item.first ] = vec_ids; ins_index.emplace_back(std::move(vec_ids));
ins_name_map[var_name_item.first] = ins_index.size() - 1;
} // ins_map[ var_name_item.first ] = input_vars;
if (debug ) cerr << "1" << endl; // ins_name2id[ var_name_item.first ] = vec_ids;
}
if (debug) std::cerr << "1" << std::endl;
VariableValueMap outs_map;
std::map<std::string, std::vector<int> > outs_name2id; // VariableValueMap outs_map;
for( auto& var_name_item : output_names ) // std::map<std::string, std::vector<int> > outs_name2id;
{ std::vector<std::vector<Variable*>> outs_value;
std::vector<std::vector<size_t>> outs_index;
std::map<std::string, size_t> outs_name_map;
for (auto& var_name_item : output_names) {
std::vector<Variable*> output_vars; std::vector<Variable*> output_vars;
std::vector<int> vec_ids; std::vector<size_t> vec_ids;
output_vars.reserve(var_name_item.second.size()); output_vars.reserve(var_name_item.second.size());
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
auto it = var_scope->name2id.find( var_name ); auto it = var_scope->name2id.find(var_name);
assert( it != var_scope->name2id.end() ); assert(it != var_scope->name2id.end());
//cerr << it->second << "\t" << var_scope.var_list.size() << endl; // std::cerr << it->second << "\t" << var_scope.var_list.size() <<
output_vars.push_back( var_scope->var_list[ it->second].get() ); // std::endl;
vec_ids.push_back( it->second ); output_vars.push_back(var_scope->var_list[it->second].get());
} vec_ids.push_back(it->second);
outs_map[ var_name_item.first ] = output_vars; }
//cerr << ToTypeName(output_vars[0]->Type() ) << endl; outs_value.emplace_back(std::move(output_vars));
outs_name2id[ var_name_item.first ] = vec_ids; outs_index.emplace_back(std::move(vec_ids));
} outs_name_map[var_name_item.first] = outs_index.size() - 1;
// outs_map[ var_name_item.first ] = output_vars;
// //std::cerr << ToTypeName(output_vars[0]->Type() ) << std::endl;
op_func_node.input_index = ins_name2id; // outs_name2id[ var_name_item.first ] = vec_ids;
op_func_node.output_index = outs_name2id; }
RuntimeContext runtime_context( {}, {});
runtime_context.inputs.swap( ins_map ); // op_func_node.input_index = ins_name2id;
runtime_context.outputs.swap( outs_map ); // op_func_node.output_index = outs_name2id;
//cerr << "create runtime context" << endl; op_func_node.input_index = ins_index;
op_func_node.input_name_map = ins_name_map;
op_func_node.output_index = outs_index;
op_func_node.output_name_map = outs_name_map;
RuntimeContextV2 runtime_context(ins_value, outs_value, ins_name_map,
outs_name_map);
// runtime_context.inputs.swap( ins_map );
// runtime_context.outputs.swap( outs_map );
// runtime_context.input_values.swap(ins_value);
// runtime_context.input_name_map = ins_name_map;
// runtime_context.output_values.swap(outs_value);
// runtime_context.output_name_map = outs_name_map;
// std::cerr << "create runtime context" << std::endl;
RuntimeInferShapeContext infer_shape_ctx(*op_base, runtime_context); RuntimeInferShapeContext infer_shape_ctx(*op_base, runtime_context);
static_cast<const framework::OperatorWithKernel*>(op_base)->InferShape( &infer_shape_ctx ); static_cast<const framework::OperatorWithKernel*>(op_base)->InferShape(
//cerr << "fin infer shape" << endl; &infer_shape_ctx);
// std::cerr << "fin infer shape" << std::endl;
auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
auto kernels_iter = all_op_kernels.find(op->Type() ); auto kernels_iter = all_op_kernels.find(op->Type());
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
kernels_iter, all_op_kernels.end(), kernels_iter, all_op_kernels.end(),
platform::errors::Unavailable( platform::errors::Unavailable(
"There are no kernels which are registered in the %s operator.", "There are no kernels which are registered in the %s operator.",
op->Type() )); op->Type()));
//cerr << "create kernel" << endl; // std::cerr << "create kernel" << std::endl;
using OpKernelFunc = std::function<void(const ExecutionContext&)>; using OpKernelFunc = std::function<void(const ExecutionContext&)>;
using OpKernelMap = using OpKernelMap =
std::unordered_map<OpKernelType, OpKernelFunc, OpKernelType::Hash>; std::unordered_map<OpKernelType, OpKernelFunc, OpKernelType::Hash>;
if (debug ) cerr << "2" << endl; if (debug) std::cerr << "2" << std::endl;
OpKernelMap& kernels = kernels_iter->second; OpKernelMap& kernels = kernels_iter->second;
//auto place = platform::CPUPlace(); // auto place = platform::CPUPlace();
//auto place = platform::CUDAPlace(0); // auto place = platform::CUDAPlace(0);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
Scope scope; Scope scope;
auto exec_ctx = ExecutionContext(*op_base, scope, *dev_ctx, runtime_context ); auto exec_ctx =
if (debug ) cerr << "21" << endl; ExecutionContextV2(*op_base, scope, *dev_ctx, runtime_context);
auto expected_kernel_key = dynamic_cast<const framework::OperatorWithKernel*>(op_base)->GetExpectedKernelType( exec_ctx ); if (debug) std::cerr << "21" << std::endl;
if (debug ) cerr << "22" << endl; auto expected_kernel_key =
//cerr << "22" << endl; dynamic_cast<const framework::OperatorWithKernel*>(op_base)
->GetExpectedKernelType(exec_ctx);
if (debug) std::cerr << "22" << std::endl;
// std::cerr << "22" << std::endl;
// add transfer log // add transfer log
//cerr << "in map size " << ins_map.size() << endl; // std::cerr << "in map size " << ins_map.size() << std::endl;
VariableValueMap& ins_map_temp = runtime_context.inputs; // VariableValueMap& ins_map_temp = runtime_context.inputs;
//cerr << "ins map siz" << ins_map_temp.size() << endl; auto ins_map_temp = runtime_context.input_name_map;
for( auto& var_name_item : ins_map_temp ) // std::cerr << "ins map siz" << ins_map_temp.size() << std::endl;
{ for (auto& var_name_item : ins_map_temp) {
// std::cerr << "in name " << var_name_item.first << std::endl;
//auto& vec_ids = ins_name2id[ var_name_item.first ]; // auto& vec_ids = ins_name2id[ var_name_item.first ];
for( size_t i = 0; i < var_name_item.second.size(); ++i ) for (size_t i = 0;
{ i < runtime_context.input_values[var_name_item.second].size(); ++i) {
auto var = var_name_item.second[i]; auto var = runtime_context.input_values[var_name_item.second][i];
auto tensor_in = static_cast<const Tensor*>(&(var->Get<LoDTensor>())); auto tensor_in = static_cast<const Tensor*>(&(var->Get<LoDTensor>()));
if( !tensor_in->IsInitialized() ) if (!tensor_in->IsInitialized()) {
{
continue; continue;
} }
//cerr << "i " << i << "\t" << tensor_in->IsInitialized() << endl; // std::cerr << "i " << i << "\t" << tensor_in->IsInitialized() <<
auto kernel_type_for_var = static_cast<const framework::OperatorWithKernel*>(op_base)->GetKernelTypeForVar( // std::endl;
var_name_item.first, *tensor_in, expected_kernel_key); auto kernel_type_for_var =
if( debug) static_cast<const framework::OperatorWithKernel*>(op_base)
{ ->GetKernelTypeForVar(var_name_item.first, *tensor_in,
cerr << "var name " << var_name_item.first << endl; expected_kernel_key);
cerr << expected_kernel_key.place_ << "\t" << kernel_type_for_var.place_ << endl; if (debug) {
} std::cerr << "var name " << var_name_item.first << std::endl;
if ( !platform::is_same_place(kernel_type_for_var.place_, std::cerr << expected_kernel_key.place_ << "\t"
expected_kernel_key.place_) ) << kernel_type_for_var.place_ << std::endl;
{ }
if(debug) cerr << "add data transfer" << endl; if (!platform::is_same_place(kernel_type_for_var.place_,
expected_kernel_key.place_)) {
if (debug) std::cerr << "add data transfer" << std::endl;
// need trans place // need trans place
// add var in scope // add var in scope
// add copy op // add copy op
std::string new_var_name = "temp_1" + to_string( var_scope->var_list.size() + 1); std::string new_var_name =
"temp_1" + std::to_string(var_scope->var_list.size() + 1);
auto v = new Variable(); auto v = new Variable();
v->GetMutable<LoDTensor>(); v->GetMutable<LoDTensor>();
var_scope->name2id[ new_var_name ] = var_scope->var_list.size(); var_scope->name2id[new_var_name] = var_scope->var_list.size();
var_scope->var_list.push_back(std::unique_ptr<Variable>(v)); var_scope->var_list.push_back(std::unique_ptr<Variable>(v));
VariableNameMap copy_in_map; VariableNameMap copy_in_map;
//cerr << "ints name is " << input_names[var_name_item.first][i] << endl; // std::cerr << "ints name is " << input_names[var_name_item.first][i]
copy_in_map["X"] = { input_names[var_name_item.first][i] }; // << std::endl;
copy_in_map["X"] = {input_names[var_name_item.first][i]};
VariableNameMap copy_out_map; VariableNameMap copy_out_map;
copy_out_map["Out"] = { new_var_name }; copy_out_map["Out"] = {new_var_name};
AttributeMap attr_map; AttributeMap attr_map;
attr_map["dst_place_type"] = convert( place ); attr_map["dst_place_type"] = convert(place);
std::map< std::string, std::vector<int> > copy_ins_name2id; // std::map< std::string, std::vector<int> > copy_ins_name2id;
copy_ins_name2id["X"] = ins_name2id[ var_name_item.first ]; // copy_ins_name2id["X"] = ins_name2id[ var_name_item.first ];
std::map< std::string, std::vector<int> > copy_out_name2id; // std::map< std::string, std::vector<int> > copy_out_name2id;
copy_out_name2id["Out"] = { var_scope->name2id[new_var_name]}; // copy_out_name2id["Out"] = { var_scope->name2id[new_var_name]};
//vec_ids[i] = var_scope->name2id[new_var_name]; // vec_ids[i] = var_scope->name2id[new_var_name];
// update out runtime_context // update out runtime_context
op_func_node.input_index[ var_name_item.first ][i] = var_scope->name2id[new_var_name]; op_func_node
.input_index[op_func_node.input_name_map[var_name_item.first]]
VariableValueMap copy_ins_value_map; [i] = var_scope->name2id[new_var_name];
copy_ins_value_map["X"] = { var };
VariableValueMap copy_outs_value_map; // VariableValueMap copy_ins_value_map;
copy_outs_value_map["Out"] = { v }; // copy_ins_value_map["X"] = { var };
// VariableValueMap copy_outs_value_map;
// copy_outs_value_map["Out"] = { v };
auto& copy_info = OpInfoMap::Instance().Get( "memcpy" ); auto& copy_info = OpInfoMap::Instance().Get("memcpy");
auto copy_op = copy_info.Creator()( "memcpy", copy_in_map, copy_out_map, attr_map); auto copy_op = copy_info.Creator()("memcpy", copy_in_map,
if(debug) cerr << "create memcpy" << endl; copy_out_map, attr_map);
if (debug) std::cerr << "create memcpy" << std::endl;
OpFuncNode copy_op_func_node; OpFuncNode copy_op_func_node;
copy_op_func_node.input_index = copy_ins_name2id; // copy_op_func_node.input_index = copy_ins_name2id;
copy_op_func_node.output_index = copy_out_name2id; // copy_op_func_node.output_index = copy_out_name2id;
copy_op_func_node.input_index.push_back(
RuntimeContext copy_runtime_context( {}, {}); ins_index[ins_name_map[var_name_item.first]]);
copy_runtime_context.inputs.swap( copy_ins_value_map ); copy_op_func_node.input_name_map["X"] = 0;
copy_runtime_context.outputs.swap( copy_outs_value_map ); copy_op_func_node.output_index.push_back(
//cerr << "create runtime context" << endl; {var_scope->name2id[new_var_name]});
RuntimeInferShapeContext copy_infer_shape_ctx(*copy_op, copy_runtime_context); copy_op_func_node.output_name_map["Out"] = 0;
if(debug) cerr << "before infer shape" << endl; std::vector<std::vector<Variable*>> in_values;
static_cast<const framework::OperatorWithKernel*>(copy_op)->InferShape( &copy_infer_shape_ctx ); std::vector<std::vector<Variable*>> out_values;
if(debug) cerr << "infer shape" << endl; in_values.push_back({var});
//cerr << "fin infer shape" << endl; out_values.push_back({v});
RuntimeContextV2 copy_runtime_context(
in_values, out_values, copy_op_func_node.input_name_map,
copy_op_func_node.output_name_map);
// copy_runtime_context.input_values.push_back({var});
// copy_runtime_context.input_name_map["X"] = 0;
// copy_runtime_context.output_values.push_back({v});
// copy_runtime_context.output_name_map["Out"] = 0;
// copy_runtime_context.inputs.swap( copy_ins_value_map );
// copy_runtime_context.outputs.swap( copy_outs_value_map );
// std::cerr << "create runtime context" << std::endl;
RuntimeInferShapeContext copy_infer_shape_ctx(*copy_op,
copy_runtime_context);
if (debug) std::cerr << "before infer shape" << std::endl;
static_cast<const framework::OperatorWithKernel*>(copy_op)
->InferShape(&copy_infer_shape_ctx);
if (debug) std::cerr << "infer shape" << std::endl;
// std::cerr << "fin infer shape" << std::endl;
auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
auto kernels_iter = all_op_kernels.find( "memcpy" ); auto kernels_iter = all_op_kernels.find("memcpy");
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(kernels_iter, all_op_kernels.end(),
kernels_iter, all_op_kernels.end(), platform::errors::Unavailable(
platform::errors::Unavailable("There are no kernels which are registered in the memcpy operator.") ); "There are no kernels which are registered in "
"the memcpy operator."));
//cerr << "create kernel" << endl; // std::cerr << "create kernel" << std::endl;
using OpKernelFunc = std::function<void(const ExecutionContext&)>; using OpKernelFunc = std::function<void(const ExecutionContext&)>;
using OpKernelMap = using OpKernelMap = std::unordered_map<OpKernelType, OpKernelFunc,
std::unordered_map<OpKernelType, OpKernelFunc, OpKernelType::Hash>; OpKernelType::Hash>;
OpKernelMap& kernels = kernels_iter->second; OpKernelMap& kernels = kernels_iter->second;
//auto place = platform::CPUPlace(); // auto place = platform::CPUPlace();
//auto place = platform::CUDAPlace(0); // auto place = platform::CUDAPlace(0);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
Scope scope; Scope scope;
auto copy_exec_ctx = ExecutionContext(*copy_op, scope, *dev_ctx, copy_runtime_context ); auto copy_exec_ctx = ExecutionContextV2(*copy_op, scope, *dev_ctx,
if (debug ) cerr << "21" << endl; copy_runtime_context);
auto expected_kernel_key = dynamic_cast<const framework::OperatorWithKernel*>(copy_op)->GetExpectedKernelType( copy_exec_ctx ); if (debug) std::cerr << "21" << std::endl;
if (debug ) cerr << "22" << endl; auto expected_kernel_key =
//cerr << "22" << endl; dynamic_cast<const framework::OperatorWithKernel*>(copy_op)
->GetExpectedKernelType(copy_exec_ctx);
if (debug) std::cerr << "22" << std::endl;
// std::cerr << "22" << std::endl;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
copy_op_func_node.kernel_func_ = OpKernelFunc( kernel_iter->second ); copy_op_func_node.kernel_func_ = OpKernelFunc(kernel_iter->second);
copy_op_func_node.kernel_func_( copy_exec_ctx ); copy_op_func_node.kernel_func_(copy_exec_ctx);
if(debug) cerr << "run exe ctx" << endl; if (debug) std::cerr << "run exe ctx" << std::endl;
op_list.push_back( copy_op );
vec_func_list.push_back( copy_op_func_node);
op_list.push_back(copy_op);
vec_func_list.push_back(copy_op_func_node);
var_name_item.second[i] = v; runtime_context.input_values[var_name_item.second][i] = v;
} }
} }
} }
op_list.push_back( op_base ); op_list.push_back(op_base);
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
if (debug ) cerr << "3" << endl; if (debug) std::cerr << "3" << std::endl;
op_func_node.kernel_func_ = OpKernelFunc(kernel_iter->second); op_func_node.kernel_func_ = OpKernelFunc(kernel_iter->second);
if (debug ) cerr << "3-1" << endl; if (debug) std::cerr << "3-1" << std::endl;
op_func_node.kernel_func_( exec_ctx ); op_func_node.kernel_func_(exec_ctx);
vec_func_list.push_back( op_func_node ); vec_func_list.push_back(op_func_node);
if (debug ) cerr << "5" << endl; if (debug) std::cerr << "5" << std::endl;
} }
} }
void exec_op_func_list(const std::vector<OpFuncNode>& vec_func_list,
std::vector<OperatorBase*>& op_list, // NOLINT
void exec_op_func_list( const std::vector<OpFuncNode>& vec_func_list,
std::vector< OperatorBase* >& op_list,
const VariableScope& var_scope, const VariableScope& var_scope,
const platform::Place& place) const platform::Place& place) {
{ for (size_t i = 0; i < vec_func_list.size(); ++i) {
for( size_t i = 0; i < vec_func_list.size(); ++i )
{
auto& func_node = vec_func_list[i]; auto& func_node = vec_func_list[i];
auto op_base = op_list[i]; auto op_base = op_list[i];
// build runtime cost // build runtime cost
VariableValueMap ins_map; // VariableValueMap ins_map;
for( auto& var_name_item : func_node.input_index) std::vector<std::vector<Variable*>> ins_map;
{ for (auto& var_name_item : func_node.input_name_map) {
std::vector<Variable*> input_vars; std::vector<Variable*> input_vars;
input_vars.reserve(var_name_item.second.size()); input_vars.reserve(func_node.input_index[var_name_item.second].size());
for (auto& id : var_name_item.second) { for (auto& id : func_node.input_index[var_name_item.second]) {
//cerr << var_name_item.first << "\t " << id << endl; // std::cerr << var_name_item.first << "\t " << id << std::endl;
input_vars.emplace_back( var_scope.var_list[ id ].get() ); input_vars.emplace_back(var_scope.var_list[id].get());
} }
ins_map.emplace( var_name_item.first, std::move(input_vars) ); // ins_map.emplace( var_name_item.first, std::move(input_vars) );
ins_map.emplace_back(std::move(input_vars));
} }
VariableValueMap outs_map; // VariableValueMap outs_map;
for( auto& var_name_item : func_node.output_index) std::vector<std::vector<Variable*>> outs_map;
{ for (auto& var_name_item : func_node.output_name_map) {
std::vector<Variable*> out_vars; std::vector<Variable*> out_vars;
out_vars.reserve(var_name_item.second.size()); out_vars.reserve(func_node.output_index[var_name_item.second].size());
for (auto& id : var_name_item.second) { for (auto& id : func_node.output_index[var_name_item.second]) {
//cerr << var_name_item.first << "\t " << id << endl; // std::cerr << var_name_item.first << "\t " << id << std::endl;
out_vars.emplace_back( var_scope.var_list[ id ].get()); out_vars.emplace_back(var_scope.var_list[id].get());
} }
outs_map.emplace( var_name_item.first, std::move( out_vars ) ); // outs_map.emplace( var_name_item.first, std::move( out_vars ) );
outs_map.emplace_back(std::move(out_vars));
} }
RuntimeContext runtime_context( {}, {}); RuntimeContextV2 runtime_context(
runtime_context.inputs.swap( ins_map ); ins_map, outs_map, func_node.input_name_map, func_node.output_name_map);
runtime_context.outputs.swap( outs_map ); // runtime_context.inputs.swap( ins_map );
// runtime_context.outputs.swap( outs_map );
// runtime_context.input_values.swap(ins_map);
// runtime_context.output_values.swap(outs_map);
// runtime_context.input_name_map = func_node.input_name_map;
// runtime_context.output_name_map = func_node.output_name_map;
RuntimeInferShapeContext infer_shape_ctx( *op_base, runtime_context); RuntimeInferShapeContext infer_shape_ctx(*op_base, runtime_context);
//dynamic_cast<const framework::OperatorWithKernel*>(op_base)->InferShape( &infer_shape_ctx );
//RuntimeInferShapeContext infer_shape_ctx(*op_base, runtime_context);
static_cast<const framework::OperatorWithKernel*>(op_base)->InferShape( &infer_shape_ctx );
// dynamic_cast<const framework::OperatorWithKernel*>(op_base)->InferShape(
// &infer_shape_ctx );
// RuntimeInferShapeContext infer_shape_ctx(*op_base, runtime_context);
static_cast<const framework::OperatorWithKernel*>(op_base)->InferShape(
&infer_shape_ctx);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
//auto place = platform::CPUPlace(); // auto place = platform::CPUPlace();
//auto place = platform::CUDAPlace(0); // auto place = platform::CUDAPlace(0);
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
Scope scope; Scope scope;
auto exec_context =
ExecutionContextV2(*op_base, scope, *dev_ctx, runtime_context);
auto exec_context = ExecutionContext(*op_base, scope, *dev_ctx, runtime_context ); func_node.kernel_func_(exec_context);
func_node.kernel_func_( exec_context );
} }
} }
class InterpreterCore class InterpreterCore {
{ public:
public: InterpreterCore(const platform::Place& place, const ProgramDesc& prog,
InterpreterCore( const platform::Place& place, const ProgramDesc& prog, const ProgramDesc& startup_prog) : place_(place), prog_(prog) { const ProgramDesc& startup_prog)
: place_(place), prog_(prog) {
paddle::framework::InitDevices(); paddle::framework::InitDevices();
is_build = false; is_build = false;
paddle::framework::build_variable_scope( startup_prog, &global_scope ); paddle::framework::build_variable_scope(startup_prog, &global_scope);
std::vector<paddle::framework::OpFuncNode> vec_func_list; std::vector<paddle::framework::OpFuncNode> vec_func_list;
std::vector< paddle::framework::OperatorBase* > op_list; std::vector<paddle::framework::OperatorBase*> op_list;
paddle::framework::build_op_func_list( startup_prog, op_list, vec_func_list, &global_scope, place_); paddle::framework::build_op_func_list(startup_prog, op_list, vec_func_list,
&global_scope, place_);
} }
void run( const std::vector<std::string> vec_name, const std::vector<framework::Tensor>& vec_tensor, const vector<std::string>& vec_fetch_name, void run(const std::vector<std::string> vec_name,
std::vector<framework::Tensor>& vec_out) const std::vector<framework::Tensor>& vec_tensor,
{ const std::vector<std::string>& vec_fetch_name,
//cerr << "run" << endl; std::vector<framework::Tensor>& vec_out) { // NOLINT
// std::cerr << "run" << std::endl;
// set static data // set static data
if( is_build == false ) if (is_build == false) {
{ paddle::framework::build_variable_scope(prog_, &global_scope);
paddle::framework::build_variable_scope( prog_, &global_scope );
} }
for ( size_t i = 0; i < vec_name.size(); ++i )
{
auto it = global_scope.name2id.find( vec_name[i] );
//cerr << "find " << ( it != global_scope.name2id.end() ) <<endl;
assert( it != global_scope.name2id.end() );
auto feed_tensor = global_scope.var_list[ it->second]->GetMutable<framework::LoDTensor>(); for (size_t i = 0; i < vec_name.size(); ++i) {
//cerr << " get tensor" << endl; auto it = global_scope.name2id.find(vec_name[i]);
feed_tensor->ShareDataWith( vec_tensor[i] ); // std::cerr << "find " << (it != global_scope.name2id.end()) <<
//cerr << "share buffer with" << endl; // std::endl;
assert(it != global_scope.name2id.end());
auto feed_tensor =
global_scope.var_list[it->second]->GetMutable<framework::LoDTensor>();
// std::cerr << " get tensor" << std::endl;
feed_tensor->ShareDataWith(vec_tensor[i]);
// std::cerr << "share buffer with" << std::endl;
} }
if( is_build == false ) if (is_build == false) {
{ paddle::framework::build_op_func_list(prog_, op_list, vec_func_list,
paddle::framework::build_op_func_list( prog_, op_list, vec_func_list, &global_scope, place_); &global_scope, place_);
is_build = true; is_build = true;
} else {
paddle::framework::exec_op_func_list(vec_func_list, op_list, global_scope,
place_);
} }
else
{
paddle::framework::exec_op_func_list( vec_func_list, op_list, global_scope, place_ );
}
for( size_t i = 0; i < vec_fetch_name.size(); ++i )
{
auto it = global_scope.name2id.find( vec_fetch_name[i] );
assert( it != global_scope.name2id.end() );
auto fetch_tensor = global_scope.var_list[ it->second]->GetMutable<framework::LoDTensor>(); for (size_t i = 0; i < vec_fetch_name.size(); ++i) {
auto it = global_scope.name2id.find(vec_fetch_name[i]);
assert(it != global_scope.name2id.end());
auto fetch_tensor =
global_scope.var_list[it->second]->GetMutable<framework::LoDTensor>();
//cerr << "out " << fetch_tensor->data<float>()[0] << endl; // std::cerr << "out " << fetch_tensor->data<float>()[0] << std::endl;
if ( platform::is_gpu_place(fetch_tensor->place() ) ) if (platform::is_gpu_place(fetch_tensor->place())) {
{ // std::cerr << "fetch gpu" << std::endl;
//cerr << "fetch gpu" << endl;
Tensor out; Tensor out;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place_); auto* dev_ctx = pool.Get(place_);
dev_ctx->Wait(); dev_ctx->Wait();
TensorCopySync(*fetch_tensor, platform::CPUPlace(), &out); TensorCopySync(*fetch_tensor, platform::CPUPlace(), &out);
dev_ctx->Wait(); dev_ctx->Wait();
//cerr << "out " << out << endl; // std::cerr << "out " << out << std::endl;
//cout << out.data<float>()[0] << endl; vec_out.push_back(out);
vec_out.push_back( out ); } else {
} // std::cerr << "out " << *fetch_tensor << std::endl;
else
{
cerr << "out " << *fetch_tensor << endl;
} }
} }
} }
private:
private:
const platform::Place& place_; const platform::Place& place_;
const ProgramDesc& prog_; const ProgramDesc& prog_;
paddle::framework::VariableScope global_scope; paddle::framework::VariableScope global_scope;
std::vector<paddle::framework::OpFuncNode> vec_func_list; std::vector<paddle::framework::OpFuncNode> vec_func_list;
std::vector< paddle::framework::OperatorBase* > op_list; std::vector<paddle::framework::OperatorBase*> op_list;
bool is_build; bool is_build;
}; };
} // namespace framework
} } // namespace paddle
}
#include <iostream> // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#include <string> //
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gperftools/profiler.h>
#include <chrono>
#include <iostream>
#include <map> #include <map>
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -9,69 +23,58 @@ ...@@ -9,69 +23,58 @@
#include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/new_exec.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/init.h"
#include "paddle/fluid/framework/new_exec.h" #include "paddle/fluid/pybind/pybind.h"
#include <chrono>
#include <gperftools/profiler.h>
int main() int main() {
{
paddle::framework::InitDevices(); paddle::framework::InitDevices();
paddle::framework::VariableScope global_scope; paddle::framework::VariableScope global_scope;
auto place = paddle::platform::CUDAPlace(0);
{ {
auto test_prog = paddle::framework::load_from_file( "lm_startup_program"); auto test_prog = paddle::framework::load_from_file("lm_startup_program");
paddle::framework::build_variable_scope( test_prog, &global_scope ); paddle::framework::build_variable_scope(test_prog, &global_scope);
std::vector<paddle::framework::OpFuncNode> vec_func_list; std::vector<paddle::framework::OpFuncNode> vec_func_list;
std::vector<std::unique_ptr<paddle::framework::OperatorBase>> op_list; std::vector<paddle::framework::OperatorBase*> op_list;
paddle::framework::build_op_func_list( test_prog, op_list, vec_func_list, global_scope); paddle::framework::build_op_func_list(test_prog, op_list, vec_func_list,
&global_scope, place);
paddle::framework::exec_op_func_list( vec_func_list, op_list, global_scope ); paddle::framework::exec_op_func_list(vec_func_list, op_list, global_scope,
place);
} }
cerr << "run main" << endl; std::cerr << "run main" << std::endl;
auto main_prog = paddle::framework::load_from_file( "lm_main_program"); auto main_prog = paddle::framework::load_from_file("lm_main_program");
paddle::framework::build_variable_scope( main_prog, &global_scope );
paddle::framework::build_variable_scope(main_prog, &global_scope);
std::vector<paddle::framework::OpFuncNode> vec_main_func_list; std::vector<paddle::framework::OpFuncNode> vec_main_func_list;
std::vector<std::unique_ptr<paddle::framework::OperatorBase>> op_main_list; std::vector<paddle::framework::OperatorBase*> op_main_list;
paddle::framework::build_op_func_list( main_prog, op_main_list, vec_main_func_list, global_scope); paddle::framework::build_op_func_list(
main_prog, op_main_list, vec_main_func_list, &global_scope, place);
auto start = std::chrono::steady_clock::now(); auto start = std::chrono::steady_clock::now();
ProfilerStart("new_executor.prof"); // ProfilerStart("new_executor.prof");
for ( size_t i = 0; i < 2320; ++i ) for (size_t i = 0; i < 2320; ++i) {
{ if (i % 200 == 0) {
if( i % 200 == 0) std::cerr << i << std::endl;
{
cerr << i << endl;
} }
paddle::framework::exec_op_func_list( vec_main_func_list, op_main_list, global_scope ); paddle::framework::exec_op_func_list(vec_main_func_list, op_main_list,
global_scope, place);
33
} }
ProfilerStop(); // ProfilerStop();
auto end = std::chrono::steady_clock::now(); auto end = std::chrono::steady_clock::now();
std::chrono::duration<double> diff = end-start; std::chrono::duration<double> diff = end - start;
cerr << "time cost " << diff.count() << endl;
std::cerr << "time cost " << diff.count() << std::endl;
return 1; return 1;
} }
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -586,7 +583,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -586,7 +583,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
public: public:
RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx) RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx)
: op_(op), ctx_(ctx) {} : op_(op), ctx_(ctx) {}
bool HasInput(const std::string& name) const override { bool HasInput(const std::string& name) const override {
// has only one input // has only one input
const auto& ins = ctx_.inputs; const auto& ins = ctx_.inputs;
...@@ -602,7 +598,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -602,7 +598,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
"Input %s should not contain more than one inputs.", name)); "Input %s should not contain more than one inputs.", name));
return in[0] != nullptr; return in[0] != nullptr;
} }
bool HasOutput(const std::string& name) const override { bool HasOutput(const std::string& name) const override {
// has only one output // has only one output
const auto& outs = ctx_.outputs; const auto& outs = ctx_.outputs;
...@@ -620,7 +615,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -620,7 +615,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
"Output %s should not contain more than one outputs.", name)); "Output %s should not contain more than one outputs.", name));
return out[0] != nullptr; return out[0] != nullptr;
} }
bool HasInputs(const std::string& name) const override { bool HasInputs(const std::string& name) const override {
const auto& ins = ctx_.inputs; const auto& ins = ctx_.inputs;
auto it = ins.find(name); auto it = ins.find(name);
...@@ -634,7 +628,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -634,7 +628,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
return true; return true;
} }
bool HasOutputs(const std::string& name) const override { bool HasOutputs(const std::string& name) const override {
const auto& outs = ctx_.outputs; const auto& outs = ctx_.outputs;
auto it = outs.find(name); auto it = outs.find(name);
...@@ -648,17 +641,13 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -648,17 +641,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
return true; return true;
} }
AttrReader Attrs() const override { return AttrReader(op_.Attrs()); } AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
std::vector<std::string> Inputs(const std::string& name) const override { std::vector<std::string> Inputs(const std::string& name) const override {
return op_.Inputs(name); return op_.Inputs(name);
} }
std::vector<std::string> Outputs(const std::string& name) const override { std::vector<std::string> Outputs(const std::string& name) const override {
return op_.Outputs(name); return op_.Outputs(name);
} }
std::string GetInputNameByIdx(size_t idx) const override { std::string GetInputNameByIdx(size_t idx) const override {
auto& op_proto = auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_; paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
...@@ -669,7 +658,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -669,7 +658,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
op_.Type(), idx, op_proto->inputs().size())); op_.Type(), idx, op_proto->inputs().size()));
return op_proto->inputs()[idx].name(); return op_proto->inputs()[idx].name();
} }
std::string GetOutputNameByIdx(size_t idx) const override { std::string GetOutputNameByIdx(size_t idx) const override {
auto& op_proto = auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_; paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
...@@ -681,7 +669,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -681,7 +669,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
op_.Type(), idx, op_proto->outputs().size())); op_.Type(), idx, op_proto->outputs().size()));
return op_proto->outputs()[idx].name(); return op_proto->outputs()[idx].name();
} }
void ShareDim(const std::string& in, const std::string& out, size_t i = 0, void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) override { size_t j = 0) override {
auto in_it = ctx_.inputs.find(in); auto in_it = ctx_.inputs.find(in);
...@@ -702,16 +689,13 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -702,16 +689,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
"The index of output dimension is out of range, " "The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.", "excepted index less than %zu, but received %zu.",
out_it->second.size(), j)); out_it->second.size(), j));
Variable* in_var = in_it->second[i]; Variable* in_var = in_it->second[i];
Variable* out_var = out_it->second[j]; Variable* out_var = out_it->second[j];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_var->Type(), out_var->Type(), in_var->Type(), out_var->Type(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The type of input (%s) and output (%s) are inconsistent.", in, "The type of input (%s) and output (%s) are inconsistent.", in,
out)); out));
if (in_var->IsType<framework::SelectedRows>()) { if (in_var->IsType<framework::SelectedRows>()) {
auto& in_sele_rows = in_var->Get<framework::SelectedRows>(); auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>(); auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
...@@ -728,7 +712,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -728,7 +712,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
"or SelectedRows.")); "or SelectedRows."));
} }
} }
void ShareAllLoD(const std::string& in, void ShareAllLoD(const std::string& in,
const std::string& out) const override { const std::string& out) const override {
auto in_it = ctx_.inputs.find(in); auto in_it = ctx_.inputs.find(in);
...@@ -740,23 +723,18 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -740,23 +723,18 @@ class RuntimeInferShapeContext : public InferShapeContext {
out_it, ctx_.outputs.end(), out_it, ctx_.outputs.end(),
platform::errors::NotFound("Output [%s] found error in Op [%s]", out, platform::errors::NotFound("Output [%s] found error in Op [%s]", out,
op_.Type())); op_.Type()));
auto& in_var_list = in_it->second; auto& in_var_list = in_it->second;
auto& out_var_list = out_it->second; auto& out_var_list = out_it->second;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_var_list.size(), out_var_list.size(), in_var_list.size(), out_var_list.size(),
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Op [%s]: Input var size should be equal with output var size", "Op [%s]: Input var size should be equal with output var size",
op_.Type())); op_.Type()));
auto& out_var_names = op_.Outputs(out); auto& out_var_names = op_.Outputs(out);
for (size_t i = 0; i < in_var_list.size(); ++i) { for (size_t i = 0; i < in_var_list.size(); ++i) {
if (out_var_names[i] == framework::kEmptyVarName) { if (out_var_names[i] == framework::kEmptyVarName) {
continue; continue;
} }
Variable* in_var = in_var_list[i]; Variable* in_var = in_var_list[i];
if (!in_var->IsType<LoDTensor>()) return; if (!in_var->IsType<LoDTensor>()) return;
Variable* out_var = out_var_list[i]; Variable* out_var = out_var_list[i];
...@@ -773,7 +751,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -773,7 +751,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
out_tensor->set_layout(in_tensor.layout()); out_tensor->set_layout(in_tensor.layout());
} }
} }
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override { size_t j = 0) const override {
auto in_it = ctx_.inputs.find(in); auto in_it = ctx_.inputs.find(in);
...@@ -794,7 +771,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -794,7 +771,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
"The index of output dimension is out of range, " "The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.", "excepted index less than %zu, but received %zu.",
out_it->second.size(), j)); out_it->second.size(), j));
Variable* in_var = in_it->second.at(i); Variable* in_var = in_it->second.at(i);
if (!in_var->IsType<LoDTensor>()) return; if (!in_var->IsType<LoDTensor>()) return;
Variable* out_var = out_it->second.at(j); Variable* out_var = out_it->second.at(j);
...@@ -805,7 +781,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -805,7 +781,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
auto& in_tensor = in_var->Get<LoDTensor>(); auto& in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>(); auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_lod(in_tensor.lod()); out_tensor->set_lod(in_tensor.lod());
// TODO(dzhwinter) : reuse ShareLoD in most operators. // TODO(dzhwinter) : reuse ShareLoD in most operators.
// Need to call ShareLayout explicitly in sequence related ops. // Need to call ShareLayout explicitly in sequence related ops.
// Shall we have a better method to shared info between in/out Tensor? // Shall we have a better method to shared info between in/out Tensor?
...@@ -826,14 +801,12 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -826,14 +801,12 @@ class RuntimeInferShapeContext : public InferShapeContext {
#endif #endif
out_tensor->set_layout(in_tensor.layout()); out_tensor->set_layout(in_tensor.layout());
} }
int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override { int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override {
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetLoDLevel is only used in compile time. The calculation of " "GetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be " "output's actual lod is different among operators so that should be "
"set in the runtime kernel.")); "set in the runtime kernel."));
} }
void SetLoDLevel(const std::string& out, int32_t lod_level, void SetLoDLevel(const std::string& out, int32_t lod_level,
size_t j = 0) const override { size_t j = 0) const override {
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
...@@ -841,9 +814,7 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -841,9 +814,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
"output's actual lod is different among operators so that should be " "output's actual lod is different among operators so that should be "
"set in the runtime kernel.")); "set in the runtime kernel."));
} }
bool IsRuntime() const override { return true; } bool IsRuntime() const override { return true; }
// TODO(paddle-dev): Can this be template? // TODO(paddle-dev): Can this be template?
std::vector<InferShapeVarPtr> GetInputVarPtrs( std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) override { const std::string& name) override {
...@@ -853,7 +824,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -853,7 +824,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
res.insert(res.begin(), vars.begin(), vars.end()); res.insert(res.begin(), vars.begin(), vars.end());
return res; return res;
} }
std::vector<InferShapeVarPtr> GetOutputVarPtrs( std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) override { const std::string& name) override {
const std::vector<Variable*>& vars = OutputVars(name); const std::vector<Variable*>& vars = OutputVars(name);
...@@ -862,7 +832,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -862,7 +832,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
res.insert(res.begin(), vars.begin(), vars.end()); res.insert(res.begin(), vars.begin(), vars.end());
return res; return res;
} }
DDim GetInputDim(const std::string& name) const override { DDim GetInputDim(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name); const std::vector<Variable*>& vars = InputVars(name);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -872,22 +841,18 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -872,22 +841,18 @@ class RuntimeInferShapeContext : public InferShapeContext {
name, vars.size())); name, vars.size()));
return this->GetDim(vars[0]); return this->GetDim(vars[0]);
} }
std::vector<DDim> GetInputsDim(const std::string& name) const override { std::vector<DDim> GetInputsDim(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name); const std::vector<Variable*>& vars = InputVars(name);
return GetDims(vars); return GetDims(vars);
} }
std::vector<proto::VarType::Type> GetInputsVarType( std::vector<proto::VarType::Type> GetInputsVarType(
const std::string& name) const override { const std::string& name) const override {
return GetVarTypes(InputVars(name)); return GetVarTypes(InputVars(name));
} }
std::vector<proto::VarType::Type> GetOutputsVarType( std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string& name) const override { const std::string& name) const override {
return GetVarTypes(OutputVars(name)); return GetVarTypes(OutputVars(name));
} }
void SetOutputDim(const std::string& name, const DDim& dim) override { void SetOutputDim(const std::string& name, const DDim& dim) override {
auto& vars = OutputVars(name); auto& vars = OutputVars(name);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -897,13 +862,11 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -897,13 +862,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
name, vars.size())); name, vars.size()));
SetDim(vars[0], dim); SetDim(vars[0], dim);
} }
void SetOutputsDim(const std::string& name, void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override { const std::vector<DDim>& dims) override {
auto& vars = OutputVars(name); auto& vars = OutputVars(name);
SetDims(vars, dims); SetDims(vars, dims);
} }
protected: protected:
DDim GetDim(Variable* var) const { DDim GetDim(Variable* var) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
...@@ -919,7 +882,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -919,7 +882,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
ToTypeName(var->Type()))); ToTypeName(var->Type())));
} }
} }
std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const { std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const {
std::vector<DDim> ret; std::vector<DDim> ret;
ret.reserve(vars.size()); ret.reserve(vars.size());
...@@ -927,12 +889,10 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -927,12 +889,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
[this](Variable* var) { return this->GetDim(var); }); [this](Variable* var) { return this->GetDim(var); });
return ret; return ret;
} }
std::vector<DDim> GetRepeatedDims(const std::string& name) const override { std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetRepeatedDims method only ban be used in compile time.")); "GetRepeatedDims method only ban be used in compile time."));
} }
void SetDim(Variable* var, const DDim& dim) { void SetDim(Variable* var, const DDim& dim) {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim); var->GetMutable<LoDTensor>()->Resize(dim);
...@@ -945,7 +905,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -945,7 +905,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
ToTypeName(var->Type()))); ToTypeName(var->Type())));
} }
} }
void SetDims(const std::vector<Variable*>& vars, void SetDims(const std::vector<Variable*>& vars,
const std::vector<DDim>& dims) { const std::vector<DDim>& dims) {
size_t length = vars.size(); size_t length = vars.size();
...@@ -962,13 +921,11 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -962,13 +921,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
SetDim(vars[i], dims[i]); SetDim(vars[i], dims[i]);
} }
} }
void SetRepeatedDims(const std::string& name, void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override { const std::vector<DDim>& dims) override {
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
"SetRepeatedDims method only can be used in compile time.")); "SetRepeatedDims method only can be used in compile time."));
} }
std::vector<proto::VarType::Type> GetVarTypes( std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<Variable*>& vars) const { const std::vector<Variable*>& vars) const {
std::vector<proto::VarType::Type> retv; std::vector<proto::VarType::Type> retv;
...@@ -978,11 +935,9 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -978,11 +935,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
this, std::placeholders::_1)); this, std::placeholders::_1));
return retv; return retv;
} }
proto::VarType::Type GetVarType(Variable* var) const { proto::VarType::Type GetVarType(Variable* var) const {
return ToVarType(var->Type()); return ToVarType(var->Type());
} }
private: private:
const std::vector<Variable*>& InputVars(const std::string& name) const { const std::vector<Variable*>& InputVars(const std::string& name) const {
auto it = ctx_.inputs.find(name); auto it = ctx_.inputs.find(name);
...@@ -992,7 +947,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -992,7 +947,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
"Operator (%s) does not have the input (%s).", op_.Type(), name)); "Operator (%s) does not have the input (%s).", op_.Type(), name));
return it->second; return it->second;
} }
const std::vector<Variable*>& OutputVars(const std::string& name) const { const std::vector<Variable*>& OutputVars(const std::string& name) const {
auto it = ctx_.outputs.find(name); auto it = ctx_.outputs.find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
...@@ -1001,7 +955,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -1001,7 +955,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
"Operator (%s) does not have the outputs (%s).", op_.Type(), name)); "Operator (%s) does not have the outputs (%s).", op_.Type(), name));
return it->second; return it->second;
} }
const OperatorBase& op_; const OperatorBase& op_;
const RuntimeContext& ctx_; const RuntimeContext& ctx_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册