diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 1e57e9a20f3eecfac266d67276347ad4b5b780f9..3a1ffc02151f42a4fe6f103925ab424251ee8d85 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -34,22 +34,26 @@ KernelContext::GetEigenDevice() const { #endif const std::string& OperatorBase::Input(const std::string& name) const { + PADDLE_ENFORCE(in_out_idxs_ != nullptr, + "Input Output Indices could not be nullptr"); auto it = in_out_idxs_->find(name); PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", name); - if (attrs_.count("input_format") == 0) { - return inputs_[it->second]; + return inputs_.at((size_t)it->second); } else { const auto& input_format = GetAttr>("input_format"); int idx = input_format[it->second]; - return inputs_.at(idx); + return inputs_.at((size_t)idx); } } std::vector OperatorBase::Inputs(const std::string& name) const { + PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr"); auto input_format = GetAttr>("input_format"); auto offset = in_out_idxs_->at(name); + PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= inputs_.size(), + "Input Out Of Range"); return std::vector{ inputs_.begin() + input_format.at(offset), @@ -57,23 +61,25 @@ std::vector OperatorBase::Inputs(const std::string& name) const { } const std::string& OperatorBase::Output(const std::string& name) const { + PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr"); auto it = in_out_idxs_->find(name); PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", name); - if (attrs_.count("output_format") == 0) { - return outputs_[it->second]; + return outputs_.at((size_t)it->second); } else { const auto& output_format = GetAttr>("output_format"); int idx = output_format[it->second]; - return outputs_.at(idx); + return outputs_.at((size_t)idx); } } std::vector OperatorBase::Outputs(const std::string& name) const { + PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr"); auto output_format = GetAttr>("output_format"); auto offset = in_out_idxs_->at(name); - + PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= outputs_.size(), + "Output Out of Range"); return std::vector{ outputs_.begin() + output_format.at(offset), outputs_.begin() + output_format.at(offset + 1)}; diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 6d4f8046136f77ce1516820c37da340949a45652..0a14dc21144153f9a45d5227e54102983c6c2659 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -54,4 +54,3 @@ op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op softmax_op net) op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) -cc_test(fc_op_test SRCS fc_op_test.cc DEPS fc_op) diff --git a/paddle/operators/fc_op_test.cc b/paddle/operators/fc_op_test.cc deleted file mode 100644 index 796b149afef0dbc76fabeac36023e6425f55e8f5..0000000000000000000000000000000000000000 --- a/paddle/operators/fc_op_test.cc +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include -#include "paddle/framework/op_registry.h" -namespace f = paddle::framework; - -USE_OP_WITHOUT_KERNEL(fc); - -TEST(FC, create) { - for (size_t i = 0; i < 1000000; ++i) { - auto tmp = f::OpRegistry::CreateOp("fc", {"X", "W", "B"}, {"O"}, {}); - ASSERT_NE(tmp, nullptr); - } -} \ No newline at end of file