diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 08e912f52ccb570af8e1c10fc95480d479a2c6eb..5f0649ddc69aa6b9027e4fb51419a1c1ed932742 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -427,6 +427,9 @@ else() cc_test(custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog) endif() +#cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} ) +#cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler) + set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator) cc_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES}) diff --git a/paddle/fluid/framework/new_exec.h b/paddle/fluid/framework/new_exec.h new file mode 100644 index 0000000000000000000000000000000000000000..cecbdb45a9cc9a2ced0a00069dda4e0b2bab2c48 --- /dev/null +++ b/paddle/fluid/framework/new_exec.h @@ -0,0 +1,618 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/executor_gc_helper.h" +#include "paddle/fluid/framework/garbage_collector.h" +#include "paddle/fluid/framework/new_exec_util.h" +#include "paddle/fluid/framework/op_info.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/init.h" + +// USE_OP(fill_constant); +// USE_OP(elementwise_add); + +// using namespace std; + +namespace paddle { +namespace framework { + +using std::cerr; +using std::endl; + +using OpKernelComputeFunc = std::function; +using OpKernelMap = + std::unordered_map; + +framework::ProgramDesc load_from_file(const std::string& file_name) { + std::ifstream fin(file_name, std::ios::in | std::ios::binary); + fin.seekg(0, std::ios::end); + std::string buffer(fin.tellg(), ' '); + fin.seekg(0, std::ios::beg); + fin.read(&buffer[0], buffer.size()); + fin.close(); + + ProgramDesc program_desc(buffer); + return program_desc; +} + +struct OpKernelFunc { + OpKernelComputeFunc compute_func_; + OperatorBase* operator_base_; +}; + +struct VariableMetaInfo { + int var_ref_count_; +}; + +struct VariableScope { + std::vector var_list; + std::map name2id; + std::vector vec_meta_info_; +}; + +struct NextInstruction { + std::vector direct_run_; +}; + +struct EventInter {}; + +struct InstructionInfo { + std::vector dependecy_count_; +}; + +struct EventRun { + EventInter event_inter; + std::vector same_device_run_; + std::vector synchronized_run; +}; + +struct Instruction { + OpKernelFunc kernel_func_; + std::map> input_index_; + std::map> output_index_; + + std::vector gc_check_var_list; + 
NextInstruction next_instruction_; + std::vector vec_event_list_; +}; + +struct OpFuncNode { + // int unsed; + std::map> input_index; + std::map> output_index; + + OpKernelComputeFunc kernel_func_; +}; + +int convert(const platform::Place& place) { + if (is_cpu_place(place)) { + return 0; + } + if (is_gpu_place(place)) { + return 1; + } + + return -1; +} + +std::vector merge_vec(const std::vector& first, + const std::vector& second) { + std::vector out(first.size() + second.size()); + std::merge(first.begin(), first.end(), second.begin(), second.end(), + out.begin()); + + std::vector::iterator it; + it = std::unique(out.begin(), out.end()); + + out.resize(std::distance(out.begin(), it)); + + return out; +} + +void build_variable_outer_scope(const framework::ProgramDesc& pdesc, + VariableScope* var_scope, Scope* outer_scope) { + auto& global_block = pdesc.Block(0); + + for (auto& var : global_block.AllVars()) { + if (var->Name() == framework::kEmptyVarName) { + continue; + } + auto v = outer_scope->Var(var->Name()); + + if (var_scope->name2id.find(var->Name()) == var_scope->name2id.end()) { + var_scope->name2id[var->Name()] = var_scope->var_list.size(); + } + + InitializeVariable(v, var->GetType()); + var_scope->var_list.push_back(v); + } +} + +void build_variable_scope(const framework::ProgramDesc& pdesc, + VariableScope* var_scope) { + auto& global_block = pdesc.Block(0); + + for (auto& var : global_block.AllVars()) { + if (var->Name() == framework::kEmptyVarName) { + continue; + } + + if (var_scope->name2id.find(var->Name()) == var_scope->name2id.end()) { + var_scope->name2id[var->Name()] = var_scope->var_list.size(); + } + + auto v = new Variable(); + InitializeVariable(v, var->GetType()); + var_scope->var_list.push_back(v); + } +} + +void build_op_func_list(const framework::ProgramDesc& pdesc, + std::vector* op_list, + std::vector* vec_func_list, + VariableScope* var_scope, + const platform::Place& place) { + auto& global_block = pdesc.Block(0); + + for (auto& 
op : global_block.AllOps()) { + VLOG(3) << op->Type(); + // << op->Type() << endl; + + auto& info = OpInfoMap::Instance().Get(op->Type()); + + const VariableNameMap& inputs_names = op->Inputs(); + const VariableNameMap& outputs_names = op->Outputs(); + AttributeMap op_attr_map = op->GetAttrMap(); + + if (info.Checker() != nullptr) { + info.Checker()->Check(&op_attr_map); + } + auto op_base = + info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map); + + OpFuncNode op_func_node; + + VariableValueMap ins_map; + std::map> ins_name2id; + for (auto& var_name_item : inputs_names) { + std::vector input_vars; + std::vector vec_ids; + input_vars.reserve(var_name_item.second.size()); + for (auto& var_name : var_name_item.second) { + auto it = var_scope->name2id.find(var_name); + assert(it != var_scope->name2id.end()); + input_vars.push_back(var_scope->var_list[it->second]); + vec_ids.push_back(it->second); + } + ins_map[var_name_item.first] = input_vars; + ins_name2id[var_name_item.first] = vec_ids; + } + + VariableValueMap outs_map; + std::map> outs_name2id; + for (auto& var_name_item : outputs_names) { + std::vector output_vars; + std::vector vec_ids; + output_vars.reserve(var_name_item.second.size()); + for (auto& var_name : var_name_item.second) { + auto it = var_scope->name2id.find(var_name); + assert(it != var_scope->name2id.end()); + output_vars.push_back(var_scope->var_list[it->second]); + vec_ids.push_back(it->second); + } + outs_map[var_name_item.first] = output_vars; + outs_name2id[var_name_item.first] = vec_ids; + } + + op_func_node.input_index = ins_name2id; + op_func_node.output_index = outs_name2id; + RuntimeContext runtime_context({}, {}); + runtime_context.inputs.swap(ins_map); + runtime_context.outputs.swap(outs_map); + RuntimeInferShapeContext infer_shape_ctx(*op_base, runtime_context); + static_cast(op_base)->InferShape( + &infer_shape_ctx); + auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); + auto kernels_iter = 
all_op_kernels.find(op->Type()); + PADDLE_ENFORCE_NE( + kernels_iter, all_op_kernels.end(), + platform::errors::Unavailable( + "There are no kernels which are registered in the %s operator.", + op->Type())); + + OpKernelMap& kernels = kernels_iter->second; + // auto place = platform::CPUPlace(); + // auto place = platform::CUDAPlace(0); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(place); + Scope scope; + auto exec_ctx = + ExecutionContext(*op_base, scope, *dev_ctx, runtime_context); + auto expected_kernel_key = + dynamic_cast(op_base) + ->GetExpectedKernelType(exec_ctx); + + VariableValueMap& ins_map_temp = runtime_context.inputs; + + for (auto& var_name_item : ins_map_temp) { + for (size_t i = 0; i < var_name_item.second.size(); ++i) { + auto var = var_name_item.second[i]; + auto tensor_in = static_cast(&(var->Get())); + if (!tensor_in->IsInitialized()) { + continue; + } + auto kernel_type_for_var = + static_cast(op_base) + ->GetKernelTypeForVar(var_name_item.first, *tensor_in, + expected_kernel_key); + if (!platform::is_same_place(kernel_type_for_var.place_, + expected_kernel_key.place_)) { + // need trans place + // 1. add var in scope + // 2. 
add copy op + std::string new_var_name = + "temp_1" + std::to_string(var_scope->var_list.size() + 1); + auto v = new Variable(); + v->GetMutable(); + var_scope->name2id[new_var_name] = var_scope->var_list.size(); + var_scope->var_list.push_back(v); + + VariableNameMap copy_in_map; + auto x_iter = inputs_names.find(var_name_item.first); + copy_in_map["X"] = {x_iter->second[i]}; + VariableNameMap copy_out_map; + copy_out_map["Out"] = {new_var_name}; + AttributeMap attr_map; + attr_map["dst_place_type"] = convert(place); + + std::map> copy_ins_name2id; + copy_ins_name2id["X"] = ins_name2id[var_name_item.first]; + std::map> copy_out_name2id; + copy_out_name2id["Out"] = {var_scope->name2id[new_var_name]}; + + op_func_node.input_index[var_name_item.first][i] = + var_scope->name2id[new_var_name]; + + VariableValueMap copy_ins_value_map; + copy_ins_value_map["X"] = {var}; + VariableValueMap copy_outs_value_map; + copy_outs_value_map["Out"] = {v}; + + auto& copy_info = OpInfoMap::Instance().Get("memcpy"); + auto copy_op = copy_info.Creator()("memcpy", copy_in_map, + copy_out_map, attr_map); + OpFuncNode copy_op_func_node; + copy_op_func_node.input_index = copy_ins_name2id; + copy_op_func_node.output_index = copy_out_name2id; + + RuntimeContext copy_runtime_context({}, {}); + copy_runtime_context.inputs.swap(copy_ins_value_map); + copy_runtime_context.outputs.swap(copy_outs_value_map); + RuntimeInferShapeContext copy_infer_shape_ctx(*copy_op, + copy_runtime_context); + static_cast(copy_op) + ->InferShape(©_infer_shape_ctx); + auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); + auto kernels_iter = all_op_kernels.find("memcpy"); + PADDLE_ENFORCE_NE(kernels_iter, all_op_kernels.end(), + platform::errors::Unavailable( + "There are no kernels which are registered in " + "the memcpy operator.")); + + OpKernelMap& kernels = kernels_iter->second; + platform::DeviceContextPool& pool = + platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(place); + Scope 
scope; + auto copy_exec_ctx = + ExecutionContext(*copy_op, scope, *dev_ctx, copy_runtime_context); + auto expected_kernel_key = + dynamic_cast(copy_op) + ->GetExpectedKernelType(copy_exec_ctx); + auto kernel_iter = kernels.find(expected_kernel_key); + copy_op_func_node.kernel_func_ = + OpKernelComputeFunc(kernel_iter->second); + copy_op_func_node.kernel_func_(copy_exec_ctx); + op_list->push_back(copy_op); + vec_func_list->push_back(copy_op_func_node); + + var_name_item.second[i] = v; + } + } + } + + op_list->push_back(op_base); + + auto kernel_iter = kernels.find(expected_kernel_key); + PADDLE_ENFORCE_NE(kernel_iter, kernels.end(), + platform::errors::NotFound( + "Operator (%s) does not have kernel for %s.", + op->Type(), KernelTypeToString(expected_kernel_key))); + + op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second); + op_func_node.kernel_func_(exec_ctx); + vec_func_list->push_back(op_func_node); + } +} + +class InterpreterCore { + public: + InterpreterCore(const platform::Place& place, const ProgramDesc& prog, + const ProgramDesc& startup_prog, Scope* scope) + : place_(place), prog_(prog), outer_scope_(scope) { + paddle::framework::InitDevices(); + + is_build_ = false; + + if (outer_scope_ != nullptr) { + auto name_list = outer_scope_->LocalVarNames(); + for (auto name : name_list) { + auto v = outer_scope_->Var(name); + if (global_scope.name2id.find(name) == global_scope.name2id.end()) { + global_scope.name2id[name] = global_scope.var_list.size(); + } + + global_scope.var_list.push_back(v); + } + } + + paddle::framework::build_variable_outer_scope(startup_prog, &global_scope, + outer_scope_); + + std::vector vec_func_list; + std::vector op_list; + paddle::framework::build_op_func_list( + startup_prog, &op_list, &vec_func_list, &global_scope, place_); + // add variable to outer_scope + } + void run(const std::vector& vec_name, + const std::vector& vec_tensor, + const std::vector& vec_fetch_name, + std::vector* vec_out) { + if (is_build_ == 
false) { + paddle::framework::build_variable_scope(prog_, &global_scope); + } + for (size_t i = 0; i < vec_name.size(); ++i) { + auto it = global_scope.name2id.find(vec_name[i]); + assert(it != global_scope.name2id.end()); + + auto feed_tensor = + global_scope.var_list[it->second]->GetMutable(); + feed_tensor->ShareDataWith(vec_tensor[i]); + } + + if (is_build_ == false) { + paddle::framework::build_op_func_list(prog_, &op_list, &vec_func_list, + &global_scope, place_); + is_build_ = true; + // convert vec func_list to graph + convert(); + } else { + exec_instruction_list(vec_instruction_, global_scope, place_); + } + + for (size_t i = 0; i < vec_fetch_name.size(); ++i) { + auto it = global_scope.name2id.find(vec_fetch_name[i]); + assert(it != global_scope.name2id.end()); + PADDLE_ENFORCE_NE(it, global_scope.name2id.end(), + platform::errors::NotFound( + "Can't find (%d) the fetch var (%s) in scope", i, + vec_fetch_name[i])); + + auto fetch_tensor = + global_scope.var_list[it->second]->GetMutable(); + + if (platform::is_gpu_place(fetch_tensor->place())) { + Tensor out; + platform::DeviceContextPool& pool = + platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(place_); + dev_ctx->Wait(); + TensorCopySync(*fetch_tensor, platform::CPUPlace(), &out); + dev_ctx->Wait(); + vec_out->push_back(out); + } else { + Tensor out; + TensorCopySync(*fetch_tensor, platform::CPUPlace(), &out); + vec_out->push_back(out); + } + } + } + + private: + void convert() { + input_var2op_info_.resize(global_scope.var_list.size()); + + vec_instruction_.reserve(vec_func_list.size()); + dependecy_count_.resize(vec_func_list.size()); + global_scope.vec_meta_info_.resize(global_scope.var_list.size()); + for (size_t i = 0; i < vec_func_list.size(); ++i) { + Instruction temp_inst; + temp_inst.kernel_func_.compute_func_ = vec_func_list[i].kernel_func_; + temp_inst.kernel_func_.operator_base_ = op_list[i]; + temp_inst.input_index_ = vec_func_list[i].input_index; + 
temp_inst.output_index_ = vec_func_list[i].output_index; + + std::vector gc_check_input_list; + for (auto& item : vec_func_list[i].input_index) { + for (auto id : item.second) { + input_var2op_info_[id].push_back(i); + gc_check_input_list.push_back(id); + } + } + std::sort(gc_check_input_list.begin(), gc_check_input_list.end()); + auto last = + std::unique(gc_check_input_list.begin(), gc_check_input_list.end()); + gc_check_input_list.erase(last, gc_check_input_list.end()); + for (auto var_id : gc_check_input_list) { + global_scope.vec_meta_info_[var_id].var_ref_count_++; + } + + temp_inst.gc_check_var_list.swap(gc_check_input_list); + + vec_instruction_.push_back(temp_inst); + } + + for (size_t i = 0; i < vec_instruction_.size(); ++i) { + std::vector vec_temp; + for (auto& item : vec_instruction_[i].output_index_) { + for (auto id : item.second) { + vec_temp = merge_vec(vec_temp, input_var2op_info_[id]); + } + } + + // In Program, op order is a very import information. + // Op can noly add op after it as next as next ops. 
+ std::vector filter_next; + filter_next.reserve(vec_temp.size()); + for (auto item : vec_temp) { + if (item > i) { + filter_next.push_back(item); + } + } + vec_instruction_[i].next_instruction_.direct_run_ = filter_next; + + for (auto inst_id : filter_next) { + dependecy_count_[inst_id]++; + } + } + } + + void run_instr(const Instruction& instr_node, const VariableScope& var_scope, + const platform::Place& place) { + auto op_base = instr_node.kernel_func_.operator_base_; + // build runtime cost + VariableValueMap ins_map; + for (auto& var_name_item : instr_node.input_index_) { + std::vector input_vars; + + input_vars.reserve(var_name_item.second.size()); + for (auto& id : var_name_item.second) { + input_vars.emplace_back(var_scope.var_list[id]); + } + ins_map.emplace(var_name_item.first, std::move(input_vars)); + } + + VariableValueMap outs_map; + for (auto& var_name_item : instr_node.output_index_) { + std::vector out_vars; + + out_vars.reserve(var_name_item.second.size()); + for (auto& id : var_name_item.second) { + out_vars.emplace_back(var_scope.var_list[id]); + } + outs_map.emplace(var_name_item.first, std::move(out_vars)); + } + + RuntimeContext runtime_context({}, {}); + runtime_context.inputs.swap(ins_map); + runtime_context.outputs.swap(outs_map); + + RuntimeInferShapeContext infer_shape_ctx(*op_base, runtime_context); + + static_cast(op_base)->InferShape( + &infer_shape_ctx); + + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(place); + Scope scope; + + auto exec_context = + ExecutionContext(*op_base, scope, *dev_ctx, runtime_context); + + instr_node.kernel_func_.compute_func_(exec_context); + } + + void exec_instruction_list(const std::vector& vec_instr, + const VariableScope& var_scope, + const platform::Place& place) { + std::queue working_queue; + auto working_dependecy_count = dependecy_count_; + for (size_t i = 0; i < dependecy_count_.size(); ++i) { + if (dependecy_count_[i] == 0) { + 
working_queue.push(i); + } + } + + auto working_var_ref = global_scope.vec_meta_info_; + + size_t run_op_number = 0; + while (!working_queue.empty()) { + auto instr_id = working_queue.front(); + working_queue.pop(); + auto& instr_node = vec_instr[instr_id]; + run_instr(instr_node, var_scope, place); + + auto& next_instr = instr_node.next_instruction_.direct_run_; + ++run_op_number; + + for (auto next_i : next_instr) { + --working_dependecy_count[next_i]; + if (working_dependecy_count[next_i] == 0) { + working_queue.push(next_i); + } + } + + // GC infomation + + auto& gc_check_list = instr_node.gc_check_var_list; + for (auto var_id : gc_check_list) { + --working_var_ref[var_id].var_ref_count_; + } + } + + for (size_t i = 0; i < working_var_ref.size(); ++i) { + if (working_var_ref[i].var_ref_count_ != 0) { + cerr << " var ref is not zero " << i << endl; + } + } + } + + const platform::Place& place_; + const ProgramDesc& prog_; + paddle::framework::VariableScope global_scope; + std::vector vec_func_list; + std::vector op_list; + + bool is_build_; + + std::vector vec_instruction_; + + InstructionInfo instruction_info_; + + std::vector dependecy_count_; + std::vector ref_coun_info; + std::vector> input_var2op_info_; + + Scope* outer_scope_; +}; +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_exec_test.cc b/paddle/fluid/framework/new_exec_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7bfb6b6540cff8b597dddea9e6c7e7e726917765 --- /dev/null +++ b/paddle/fluid/framework/new_exec_test.cc @@ -0,0 +1,88 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/executor_gc_helper.h" +#include "paddle/fluid/framework/garbage_collector.h" +#include "paddle/fluid/framework/op_info.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/platform/device_context.h" + +#include "paddle/fluid/pybind/pybind.h" + +#include "gperftools/profiler.h" +#include "paddle/fluid/framework/new_exec.h" +#include "paddle/fluid/platform/init.h" + +int main() { + paddle::framework::InitDevices(); + paddle::framework::VariableScope global_scope; + auto place = paddle::platform::CUDAPlace(0); + auto test_prog = paddle::framework::load_from_file("lm_startup_program"); + { + paddle::framework::build_variable_scope(test_prog, &global_scope); + + std::vector vec_func_list; + std::vector op_list; + paddle::framework::build_op_func_list(test_prog, op_list, vec_func_list, + &global_scope, place); + + // paddle::framework::exec_op_func_list( vec_func_list, op_list, + // global_scope, place ); + } + + cerr << "run main" << endl; + auto main_prog = paddle::framework::load_from_file("lm_main_program"); + + paddle::framework::build_variable_scope(main_prog, &global_scope); + + std::vector vec_main_func_list; + std::vector op_main_list; + paddle::framework::build_op_func_list( + main_prog, 
op_main_list, vec_main_func_list, &global_scope, place); + paddle::framework::Scope scope; + paddle::framework::InterpreterCore interp_core(place, main_prog, test_prog, + &scope); + auto start = std::chrono::steady_clock::now(); + ProfilerStart("new_executor.prof"); + for (size_t i = 0; i < 2320; ++i) { + if (i % 200 == 0) { + cerr << i << endl; + } + // paddle::framework::exec_op_func_list( vec_main_func_list, op_main_list, + // global_scope, place ); + std::vector vec_out; + interp_core.run({}, {}, {}, vec_out); + } + ProfilerStop(); + auto end = std::chrono::steady_clock::now(); + std::chrono::duration diff = end - start; + + cerr << "time cost " << diff.count() << endl; + + return 1; +} diff --git a/paddle/fluid/framework/new_exec_util.h b/paddle/fluid/framework/new_exec_util.h new file mode 100644 index 0000000000000000000000000000000000000000..1783b9be74becfff70967bf132e2d609b7e6b8a6 --- /dev/null +++ b/paddle/fluid/framework/new_exec_util.h @@ -0,0 +1,472 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +/************************************************************************* + > File Name: new_exec_util.h + > Author: guanshanshan@baidu.com + > Created Time: Fri 23 Jul 2021 06:19:19 AM UTC + ************************************************************************/ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +#include "paddle/fluid/framework/executor_gc_helper.h" +#include "paddle/fluid/framework/garbage_collector.h" +#include "paddle/fluid/framework/op_info.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/init.h" + +namespace paddle { +namespace framework { + +class RuntimeInferShapeContext : public InferShapeContext { + public: + RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx) + : op_(op), ctx_(ctx) {} + + bool HasInput(const std::string& name) const override { + // has only one input + const auto& ins = ctx_.inputs; + auto it = ins.find(name); + if (it == ins.end()) { + return false; + } + const auto& in = it->second; + if (in.size() == 0) return false; + PADDLE_ENFORCE_EQ( + in.size(), 1UL, + platform::errors::InvalidArgument( + "Input %s should not contain more than one inputs.", name)); + return in[0] != nullptr; + } + + bool HasOutput(const std::string& name) const override { + // has only one output + const auto& outs = ctx_.outputs; + auto it = outs.find(name); + if (it == outs.end()) { + return false; + } + const auto& out = it->second; + if (out.size() == 0) { + return false; + } + PADDLE_ENFORCE_EQ( + out.size(), 1UL, + platform::errors::InvalidArgument( + "Output %s should not contain more than one outputs.", 
name)); + return out[0] != nullptr; + } + + bool HasInputs(const std::string& name) const override { + const auto& ins = ctx_.inputs; + auto it = ins.find(name); + if (it == ins.end() || it->second.empty()) { + return false; + } + for (auto& input : it->second) { + if (input == nullptr) { + return false; + } + } + return true; + } + + bool HasOutputs(const std::string& name) const override { + const auto& outs = ctx_.outputs; + auto it = outs.find(name); + if (it == outs.end() || it->second.empty()) { + return false; + } + for (auto& output : it->second) { + if (output == nullptr) { + return false; + } + } + return true; + } + + AttrReader Attrs() const override { return AttrReader(op_.Attrs()); } + + std::vector Inputs(const std::string& name) const override { + return op_.Inputs(name); + } + + std::vector Outputs(const std::string& name) const override { + return op_.Outputs(name); + } + + std::string GetInputNameByIdx(size_t idx) const override { + auto& op_proto = + paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_; + PADDLE_ENFORCE_LT(idx, op_proto->inputs().size(), + platform::errors::OutOfRange( + "The index should be less than the size of inputs of " + "operator %s, but got index is %d and size is %d", + op_.Type(), idx, op_proto->inputs().size())); + return op_proto->inputs()[idx].name(); + } + + std::string GetOutputNameByIdx(size_t idx) const override { + auto& op_proto = + paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_; + PADDLE_ENFORCE_LT( + idx, op_proto->outputs().size(), + platform::errors::OutOfRange( + "The index should be less than the size of outputs of " + "operator %s, but got index is %d and size is %d", + op_.Type(), idx, op_proto->outputs().size())); + return op_proto->outputs()[idx].name(); + } + + void ShareDim(const std::string& in, const std::string& out, size_t i = 0, + size_t j = 0) override { + auto in_it = ctx_.inputs.find(in); + auto out_it = ctx_.outputs.find(out); + PADDLE_ENFORCE_NE( + in_it, 
ctx_.inputs.end(), + platform::errors::NotFound("Input %s does not exist.", in)); + PADDLE_ENFORCE_NE( + out_it, ctx_.outputs.end(), + platform::errors::NotFound("Output %s does not exist.", out)); + PADDLE_ENFORCE_LT(i, in_it->second.size(), + platform::errors::InvalidArgument( + "The index of input dimension is out of range, " + "excepted index less than %zu, but received %zu.", + in_it->second.size(), i)); + PADDLE_ENFORCE_LT(j, out_it->second.size(), + platform::errors::InvalidArgument( + "The index of output dimension is out of range, " + "excepted index less than %zu, but received %zu.", + out_it->second.size(), j)); + + Variable* in_var = in_it->second[i]; + Variable* out_var = out_it->second[j]; + + PADDLE_ENFORCE_EQ( + in_var->Type(), out_var->Type(), + platform::errors::InvalidArgument( + "The type of input (%s) and output (%s) are inconsistent.", in, + out)); + + if (in_var->IsType()) { + auto& in_sele_rows = in_var->Get(); + auto out_sele_rows = out_var->GetMutable(); + out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims()); + out_sele_rows->set_rows(in_sele_rows.rows()); + out_sele_rows->set_height(in_sele_rows.height()); + } else if (in_var->IsType()) { + auto& in_lod_tensor = in_var->Get(); + auto* out_lod_tensor = out_var->GetMutable(); + out_lod_tensor->Resize(in_lod_tensor.dims()); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Currently, the input type of ShareDim only can be LoDTensor " + "or SelectedRows.")); + } + } + + void ShareAllLoD(const std::string& in, + const std::string& out) const override { + auto in_it = ctx_.inputs.find(in); + auto out_it = ctx_.outputs.find(out); + PADDLE_ENFORCE_NE(in_it, ctx_.inputs.end(), + platform::errors::NotFound( + "Input [%s] found error in Op [%s]", in, op_.Type())); + PADDLE_ENFORCE_NE( + out_it, ctx_.outputs.end(), + platform::errors::NotFound("Output [%s] found error in Op [%s]", out, + op_.Type())); + + auto& in_var_list = in_it->second; + auto& out_var_list = 
out_it->second; + + PADDLE_ENFORCE_EQ( + in_var_list.size(), out_var_list.size(), + platform::errors::PreconditionNotMet( + "Op [%s]: Input var size should be equal with output var size", + op_.Type())); + + auto& out_var_names = op_.Outputs(out); + + for (size_t i = 0; i < in_var_list.size(); ++i) { + if (out_var_names[i] == framework::kEmptyVarName) { + continue; + } + + Variable* in_var = in_var_list[i]; + if (!in_var->IsType()) return; + Variable* out_var = out_var_list[i]; + PADDLE_ENFORCE_EQ(out_var->IsType(), true, + platform::errors::PreconditionNotMet( + "The %d-th output of Output(%s) must be LoDTensor.", + i, out_var_names[i])); + auto& in_tensor = in_var->Get(); + auto* out_tensor = out_var->GetMutable(); + out_tensor->set_lod(in_tensor.lod()); +#ifdef PADDLE_WITH_MKLDNN + if (in_tensor.layout() != DataLayout::kMKLDNN) +#endif + out_tensor->set_layout(in_tensor.layout()); + } + } + + void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, + size_t j = 0) const override { + auto in_it = ctx_.inputs.find(in); + auto out_it = ctx_.outputs.find(out); + PADDLE_ENFORCE_NE( + in_it, ctx_.inputs.end(), + platform::errors::NotFound("Input %s does not exist.", in)); + PADDLE_ENFORCE_NE( + out_it, ctx_.outputs.end(), + platform::errors::NotFound("Output %s does not exist.", out)); + PADDLE_ENFORCE_LT(i, in_it->second.size(), + platform::errors::InvalidArgument( + "The index of input dimension is out of range, " + "excepted index less than %zu, but received %zu.", + in_it->second.size(), i)); + PADDLE_ENFORCE_LT(j, out_it->second.size(), + platform::errors::InvalidArgument( + "The index of output dimension is out of range, " + "excepted index less than %zu, but received %zu.", + out_it->second.size(), j)); + + Variable* in_var = in_it->second.at(i); + if (!in_var->IsType()) return; + Variable* out_var = out_it->second.at(j); + PADDLE_ENFORCE_EQ( + out_var->IsType(), true, + platform::errors::InvalidArgument( + "The %zu-th output of Output(%s) 
must be LoDTensor.", j, out)); + auto& in_tensor = in_var->Get(); + auto* out_tensor = out_var->GetMutable(); + out_tensor->set_lod(in_tensor.lod()); + +// TODO(dzhwinter) : reuse ShareLoD in most operators. +// Need to call ShareLayout explicitly in sequence related ops. +// Shall we have a better method to shared info between in/out Tensor? +#ifdef PADDLE_WITH_MKLDNN + // Fix me: ugly workaround below + // Correct solution: + // set_layout() should NOT be called here (i.e. ShareLoD). Instead, + // layout of output tensor should be set "manually" in Compute() + // of each OPKernel. The reason layout should NOT be shared between + // input and output "automatically" (now by InferShape()->ShareLoD()) + // is that layout transform may occur after InferShape(). + // Workaround: + // Skip set_layout() when input layout is kMKLDNN + // This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN + // OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called + // in Compute() + if (in_tensor.layout() != DataLayout::kMKLDNN) +#endif + out_tensor->set_layout(in_tensor.layout()); + } + + int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "GetLoDLevel is only used in compile time. The calculation of " + "output's actual lod is different among operators so that should be " + "set in the runtime kernel.")); + } + + void SetLoDLevel(const std::string& out, int32_t lod_level, + size_t j = 0) const override { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "SetLoDLevel is only used in compile time. The calculation of " + "output's actual lod is different among operators so that should be " + "set in the runtime kernel.")); + } + + bool IsRuntime() const override { return true; } + + // TODO(paddle-dev): Can this be template? 
+ std::vector GetInputVarPtrs( + const std::string& name) override { + const std::vector& vars = InputVars(name); + std::vector res; + res.reserve(vars.size()); + res.insert(res.begin(), vars.begin(), vars.end()); + return res; + } + + std::vector GetOutputVarPtrs( + const std::string& name) override { + const std::vector& vars = OutputVars(name); + std::vector res; + res.reserve(vars.size()); + res.insert(res.begin(), vars.begin(), vars.end()); + return res; + } + + DDim GetInputDim(const std::string& name) const override { + const std::vector& vars = InputVars(name); + PADDLE_ENFORCE_EQ( + vars.size(), 1UL, + platform::errors::InvalidArgument( + "Input(%s) should hold one element, but now it holds %zu elements.", + name, vars.size())); + return this->GetDim(vars[0]); + } + + std::vector GetInputsDim(const std::string& name) const override { + const std::vector& vars = InputVars(name); + return GetDims(vars); + } + + std::vector GetInputsVarType( + const std::string& name) const override { + return GetVarTypes(InputVars(name)); + } + + std::vector GetOutputsVarType( + const std::string& name) const override { + return GetVarTypes(OutputVars(name)); + } + + void SetOutputDim(const std::string& name, const DDim& dim) override { + auto& vars = OutputVars(name); + PADDLE_ENFORCE_EQ( + vars.size(), 1UL, + platform::errors::InvalidArgument("Output(%s) should hold one element, " + "but now it holds %zu elements.", + name, vars.size())); + SetDim(vars[0], dim); + } + + void SetOutputsDim(const std::string& name, + const std::vector& dims) override { + auto& vars = OutputVars(name); + SetDims(vars, dims); + } + + protected: + DDim GetDim(Variable* var) const { + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::InvalidArgument("Input variable is nullptr.")); + if (var->IsType()) { + return var->Get().dims(); + } else if (var->IsType()) { + return var->Get().GetCompleteDims(); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Only LoDTensor or SelectedRows 
support 'GetDim', but input " + "Variable's type is %s.", + ToTypeName(var->Type()))); + } + } + + std::vector GetDims(const std::vector& vars) const { + std::vector ret; + ret.reserve(vars.size()); + std::transform(vars.begin(), vars.end(), std::back_inserter(ret), + [this](Variable* var) { return this->GetDim(var); }); + return ret; + } + + std::vector GetRepeatedDims(const std::string& name) const override { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "GetRepeatedDims method only ban be used in compile time.")); + } + + void SetDim(Variable* var, const DDim& dim) { + if (var->IsType()) { + var->GetMutable()->Resize(dim); + } else if (var->IsType()) { + var->GetMutable()->set_height(dim[0]); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Variable type error, expect LoDTensor or SelectedRows, but received " + "(%s).", + ToTypeName(var->Type()))); + } + } + + void SetDims(const std::vector& vars, + const std::vector& dims) { + size_t length = vars.size(); + PADDLE_ENFORCE_EQ(length, dims.size(), + platform::errors::InvalidArgument( + "The number of input variables do not match the " + "number of input dimensions, the number of variables " + "is %zu, the number of dimensions is %zu.", + length, dims.size())); + for (size_t i = 0; i < length; ++i) { + if (vars[i] == nullptr) { + continue; + } + SetDim(vars[i], dims[i]); + } + } + + void SetRepeatedDims(const std::string& name, + const std::vector& dims) override { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "SetRepeatedDims method only can be used in compile time.")); + } + + std::vector GetVarTypes( + const std::vector& vars) const { + std::vector retv; + retv.resize(vars.size()); + std::transform(vars.begin(), vars.end(), retv.begin(), + std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType), + this, std::placeholders::_1)); + return retv; + } + + proto::VarType::Type GetVarType(Variable* var) const { + return ToVarType(var->Type()); + } + + private: + const std::vector& 
InputVars(const std::string& name) const { + auto it = ctx_.inputs.find(name); + PADDLE_ENFORCE_NE( + it, ctx_.inputs.end(), + platform::errors::NotFound( + "Operator (%s) does not have the input (%s).", op_.Type(), name)); + return it->second; + } + + const std::vector& OutputVars(const std::string& name) const { + auto it = ctx_.outputs.find(name); + PADDLE_ENFORCE_NE( + it, ctx_.outputs.end(), + platform::errors::NotFound( + "Operator (%s) does not have the outputs (%s).", op_.Type(), name)); + return it->second; + } + + const OperatorBase& op_; + const RuntimeContext& ctx_; +}; +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index d465e77ea1886f7f35549a043951048fb2bcb61d..0dcbb6e727de78534cd75ab8f65516d22669ade8 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -36,7 +36,6 @@ class FillConstantOp : public framework::OperatorWithKernel { i, shape[i], framework::make_ddim(shape))); } } - if (shape.empty() && ctx->HasInput("ShapeTensor")) { auto shape_dims = ctx->GetInputDim("ShapeTensor"); int num_ele = 1; diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 1c36cebe70a77ebe0547bede3bcf6e35bec86ffe..af01b71adb78e3034811ac2db5978172e93dd993 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -591,7 +591,6 @@ class ReduceGradOp : public framework::OperatorWithKernel { (in_dtype >= 0) ? 
static_cast(in_dtype) : OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - #ifdef PADDLE_WITH_MKLDNN auto CanMKLDNNReduceGradBeUsed = [&]() { auto dx_dims = ctx.Input("X")->dims(); diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.h b/paddle/fluid/operators/softmax_with_cross_entropy_op.h index 74316841a13b1771cbe815b6b0180a4747e9df70..29528ae0d29925b4a343c7a8e3bfe2d6f19a53c6 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.h +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.h @@ -111,15 +111,12 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { const Tensor* labels = context.Input("Label"); Tensor* logit_grad = context.Output(framework::GradVarName("Logits")); - const Tensor* softmax = context.Input("Softmax"); const bool use_softmax = context.Attr("use_softmax"); - if (logit_grad != softmax || !use_softmax) { framework::TensorCopy(*softmax, context.GetPlace(), context.device_context(), logit_grad); } - const bool soft_label = context.Attr("soft_label"); auto ignore_index = context.Attr("ignore_index"); @@ -133,7 +130,6 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d}); labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n}); out_grad_2d.ShareDataWith(*out_grad).Resize({n, d / axis_dim}); - auto out_grad_mat = framework::EigenMatrix::From(out_grad_2d); auto logit_grad_mat = framework::EigenMatrix::From(logit_grad_2d); auto& place = *context.template device_context() @@ -147,9 +143,8 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { logit_grad_mat.device(place) = out_grad_mat.broadcast(Eigen::DSizes(1, axis_dim)) * logit_grad_mat; - } - // use_softmax step2 - else { + } else { + // use_softmax step2 const int64_t* label_data = labels->data(); T* logit_grad_data = logit_grad->data(); const T* out_grad_data = out_grad->data(); @@ -180,7 +175,6 @@ class 
SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { } return; } - // for use_softmax=False, continue if (soft_label) { diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index b58e9050402bb7d584e0b5e9215a3af54718aa3b..5319fc5c00560b967678fdef073b5a030dce7dd1 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -42,6 +42,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/new_exec.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -1935,6 +1936,34 @@ All parameter, weight, gradient are variables in Paddle. fetch_vars); }); + py::class_(m, "InterpreterCore") + .def(py::init()) + .def("run", + [](InterpreterCore &self, + const std::unordered_map &input_dict, + std::vector vec_fetch_name) { + pybind11::gil_scoped_release release; + std::vector vec_tensor; + std::vector vec_name; + + for (auto &item : input_dict) { + framework::LoDTensor t; + SetTensorFromPyArray( + &t, item.second, platform::CPUPlace(), false); + vec_name.push_back(item.first); + vec_tensor.push_back(t); + } + + std::vector vec_out; + self.run(vec_name, vec_tensor, vec_fetch_name, &vec_out); + std::vector vec_ret; + for (size_t i = 0; i < vec_out.size(); ++i) { + vec_ret.push_back(TensorToPyArray(vec_out[i], true)); + } + return vec_ret; + }); + m.def("init_gflags", framework::InitGflags); m.def("init_glog", framework::InitGLOG); m.def("load_op_meta_info_and_register_op", diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index e7172507696ec0dc750440d5f7d755d6c650baf7..007221ca4f9ca30b93cf3661889d9244a1c8ade4 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ 
b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -686,6 +686,8 @@ add_subdirectory(asp) add_subdirectory(ir) +add_subdirectory(interpreter) + if (WITH_TESTING) set_property(TEST test_parallel_executor_mnist PROPERTY ENVIRONMENT GLOG_vmodule=all_reduce_deps_pass=10) set_property(TEST test_parallel_executor_fix_op_run_order PROPERTY ENVIRONMENT GLOG_vmodule=fix_op_run_order_pass=10) diff --git a/python/paddle/fluid/tests/unittests/interpreter/CMakeLists.txt b/python/paddle/fluid/tests/unittests/interpreter/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7692f8befdf58ceb6c0a23ebe3e2b49fc656ec3e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/interpreter/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB TEST_INTERP_CASES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}") + +foreach(target ${TEST_INTERP_CASES}) + py_test_modules(${target} MODULES ${target}) +endforeach() diff --git a/python/paddle/fluid/tests/unittests/interpreter/test_interpreter.py b/python/paddle/fluid/tests/unittests/interpreter/test_interpreter.py new file mode 100644 index 0000000000000000000000000000000000000000..bb18d28e48b67d1ba959583bc90b18d329c4e201 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/interpreter/test_interpreter.py @@ -0,0 +1,55 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import unittest

import numpy as np

import paddle
from paddle.fluid import core
from paddle.fluid.core import InterpreterCore

paddle.enable_static()


class LinearTestCase(unittest.TestCase):
    """Smoke test for the experimental InterpreterCore C++ bindings."""

    def setUp(self):
        # Run on GPU when the build supports it, otherwise fall back to CPU.
        self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
        ) else paddle.CPUPlace()

    def test_interp_base(self):
        # Build a tiny static program: c = fc(a) + (ones * 2).
        a = paddle.static.data(name="a", shape=[2, 2], dtype='float32')
        b = paddle.ones([2, 2]) * 2
        t = paddle.static.nn.fc(a, 2)
        c = t + b

        main_program = paddle.fluid.default_main_program()
        startup_program = paddle.fluid.default_startup_program()
        p = core.Place()
        p.set_place(self.place)
        inter_core = InterpreterCore(p, main_program.desc,
                                     startup_program.desc, core.Scope())

        # First run covers program preparation; the loop re-runs the cached
        # program with varying feed values.
        # NOTE(review): no numeric assertion yet -- this only checks that
        # run() completes without error; consider comparing `out` against
        # Executor results in a follow-up.
        out = inter_core.run({
            "a": np.ones(
                [2, 2], dtype="float32") * 2
        }, [c.name])
        for i in range(10):
            out = inter_core.run({
                "a": np.ones(
                    [2, 2], dtype="float32") * i
            }, [c.name])


if __name__ == "__main__":
    unittest.main()