// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Created by Jiabin on 2019-08-19. // #include #include #include #include #include "gtest/gtest.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/imperative/prepared_operator.h" #include "paddle/fluid/imperative/type_defs.h" namespace imperative = paddle::imperative; namespace platform = paddle::platform; namespace framework = paddle::framework; namespace paddle { namespace imperative { static framework::RuntimeContext PrepareRuntimeContext( const NameVarBaseMap& ins, const NameVarBaseMap& outs) { framework::VariableValueMap inputs, outputs; for (auto& in_pair : ins) { auto& in_ctx = inputs[in_pair.first]; in_ctx.reserve(in_pair.second.size()); for (auto& in_var : in_pair.second) { in_ctx.emplace_back(in_var->MutableVar()); } } for (auto& out_pair : outs) { auto& out_ctx = outputs[out_pair.first]; out_ctx.reserve(out_pair.second.size()); for (auto& out_var : out_pair.second) { out_ctx.emplace_back(out_var->MutableVar()); } } return framework::RuntimeContext(std::move(inputs), std::move(outputs)); } static framework::VariableNameMap CreateVarNameMap( const framework::OpInfo& op_info, const std::string& op_type, const NameVarBaseMap& varbase_map, bool is_input) { if (op_info.proto_ == nullptr) { return {}; } framework::VariableNameMap result; for (auto& var : is_input ? op_info.Proto().inputs() : op_info.Proto().outputs()) { auto it = varbase_map.find(var.name()); if (it == varbase_map.end()) { PADDLE_ENFORCE_EQ( var.dispensable(), true, "Var: %s not dispensable and there are no such var in inputs", var.name()); result[var.name()] = {}; } else { auto& var_vector = it->second; std::vector args; args.reserve(var_vector.size()); for (auto& var_base : var_vector) { args.emplace_back(var_base->Name()); } result[var.name()] = std::move(args); } } return result; } using vb_vector = std::vector>; using var_pair = std::pair; TEST(test_prepare_op, test_prepare_op) { std::shared_ptr vin( new imperative::VarBase(false, "vin")); std::shared_ptr vout( new imperative::VarBase(false, "vout")); framework::OpDesc desc; platform::CPUPlace place; vin->MutableVar()->GetMutable()->mutable_data( place); var_pair x_pair = var_pair("X", vb_vector(1, vin)); var_pair out_pair = var_pair("Out", vb_vector(1, vout)); imperative::NameVarBaseMap ins = {x_pair}; imperative::NameVarBaseMap outs = {out_pair}; framework::AttributeMap split_attr_map; const auto& info = framework::OpInfoMap::Instance().Get("split"); framework::VariableNameMap var_in_map = CreateVarNameMap(info, "split", ins, true); framework::VariableNameMap var_out_map = CreateVarNameMap(info, "split", outs, false); framework::OperatorWithKernel op("split", var_in_map, var_out_map, split_attr_map); framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs); ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare(ctx, op, place)); } const framework::Tensor* GetTensorFromVar(const framework::Variable& var); TEST(test_prepare_op, test_get_tensor_from_var) { std::shared_ptr vout_error( new imperative::VarBase(false, "vout_error")); vout_error->MutableVar()->GetMutable(); auto* ts = GetTensorFromVar(*vout_error->MutableVar()); ASSERT_TRUE(ts != nullptr); } } // namespace imperative } // namespace paddle USE_OP(split);