提交 8e7c3253 编写于 作者: Y Yu Yang

Add Some Checker in Input/Output

上级 9b9449fb
...@@ -34,22 +34,26 @@ KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ...@@ -34,22 +34,26 @@ KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
#endif #endif
const std::string& OperatorBase::Input(const std::string& name) const { const std::string& OperatorBase::Input(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr,
"Input Output Indices could not be nullptr");
auto it = in_out_idxs_->find(name); auto it = in_out_idxs_->find(name);
PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",
name); name);
if (attrs_.count("input_format") == 0) { if (attrs_.count("input_format") == 0) {
return inputs_[it->second]; return inputs_.at((size_t)it->second);
} else { } else {
const auto& input_format = GetAttr<std::vector<int>>("input_format"); const auto& input_format = GetAttr<std::vector<int>>("input_format");
int idx = input_format[it->second]; int idx = input_format[it->second];
return inputs_.at(idx); return inputs_.at((size_t)idx);
} }
} }
std::vector<std::string> OperatorBase::Inputs(const std::string& name) const { std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr");
auto input_format = GetAttr<std::vector<int>>("input_format"); auto input_format = GetAttr<std::vector<int>>("input_format");
auto offset = in_out_idxs_->at(name); auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= inputs_.size(),
"Input Out Of Range");
return std::vector<std::string>{ return std::vector<std::string>{
inputs_.begin() + input_format.at(offset), inputs_.begin() + input_format.at(offset),
...@@ -57,23 +61,25 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const { ...@@ -57,23 +61,25 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
} }
const std::string& OperatorBase::Output(const std::string& name) const { const std::string& OperatorBase::Output(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr");
auto it = in_out_idxs_->find(name); auto it = in_out_idxs_->find(name);
PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",
name); name);
if (attrs_.count("output_format") == 0) { if (attrs_.count("output_format") == 0) {
return outputs_[it->second]; return outputs_.at((size_t)it->second);
} else { } else {
const auto& output_format = GetAttr<std::vector<int>>("output_format"); const auto& output_format = GetAttr<std::vector<int>>("output_format");
int idx = output_format[it->second]; int idx = output_format[it->second];
return outputs_.at(idx); return outputs_.at((size_t)idx);
} }
} }
std::vector<std::string> OperatorBase::Outputs(const std::string& name) const { std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr");
auto output_format = GetAttr<std::vector<int>>("output_format"); auto output_format = GetAttr<std::vector<int>>("output_format");
auto offset = in_out_idxs_->at(name); auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= outputs_.size(),
"Output Out of Range");
return std::vector<std::string>{ return std::vector<std::string>{
outputs_.begin() + output_format.at(offset), outputs_.begin() + output_format.at(offset),
outputs_.begin() + output_format.at(offset + 1)}; outputs_.begin() + output_format.at(offset + 1)};
......
...@@ -54,4 +54,3 @@ op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op ...@@ -54,4 +54,3 @@ op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
softmax_op net) softmax_op net)
op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
cc_test(fc_op_test SRCS fc_op_test.cc DEPS fc_op)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
namespace f = paddle::framework;
USE_OP_WITHOUT_KERNEL(fc);
TEST(FC, create) {
for (size_t i = 0; i < 1000000; ++i) {
auto tmp = f::OpRegistry::CreateOp("fc", {"X", "W", "B"}, {"O"}, {});
ASSERT_NE(tmp, nullptr);
}
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册