Commit 36524bb2 authored by Yu Yang

Add Error in FuncConfig.

* Also test std::vector
* Use std::vector to PadConf
Parent commit: 7217e834
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/math/Matrix.h" #include "paddle/math/Matrix.h"
#include "paddle/utils/Any.h" #include "paddle/utils/Any.h"
#include "paddle/utils/ClassRegistrar.h" #include "paddle/utils/ClassRegistrar.h"
#include "paddle/utils/Error.h"
namespace paddle { namespace paddle {
...@@ -30,12 +31,30 @@ namespace paddle { ...@@ -30,12 +31,30 @@ namespace paddle {
class FuncConfig { class FuncConfig {
public: public:
template <typename T> template <typename T>
T get(const std::string& key) const { T get(const std::string& key, Error* err = nullptr) const {
return any_cast<T>(valueMap_[key]); try {
return any_cast<T>(valueMap_.at(key));
} catch (std::exception& e) { // could be cast or out of range exception.
if (err) {
*err = Error(e.what());
} else {
LOG(FATAL) << "Cannot get key " << key << "with error " << e.what();
}
return T();
}
} }
template <typename T> template <typename T>
FuncConfig& set(const std::string& key, T v) { FuncConfig& set(const std::string& key, T v, Error* err = nullptr) {
auto it = valueMap_.find(key);
if (it != valueMap_.end()) { // already contains key.
if (err) {
*err = Error("Key %s is already set in FuncConfig", key.c_str());
} else {
LOG(FATAL) << "Key " << key << " is already set in FuncConfig.";
}
return *this;
}
valueMap_[key] = any(v); valueMap_[key] = any(v);
return *this; return *this;
} }
......
...@@ -25,9 +25,9 @@ void Pad<DEVICE_TYPE_CPU>(real* outputs, ...@@ -25,9 +25,9 @@ void Pad<DEVICE_TYPE_CPU>(real* outputs,
const int inH, const int inH,
const int inW, const int inW,
const PadConf& pad) { const PadConf& pad) {
int cstart = pad.channelStart, cend = pad.channelEnd; int cstart = pad.channel[0], cend = pad.channel[1];
int hstart = pad.heightStart, hend = pad.heightEnd; int hstart = pad.height[0], hend = pad.height[1];
int wstart = pad.widthStart, wend = pad.widthEnd; int wstart = pad.width[0], wend = pad.width[1];
int outC = inC + cstart + cend; int outC = inC + cstart + cend;
int outH = inH + hstart + hend; int outH = inH + hstart + hend;
int outW = inW + wstart + wend; int outW = inW + wstart + wend;
...@@ -51,9 +51,9 @@ void PadGrad<DEVICE_TYPE_CPU>(real* inGrad, ...@@ -51,9 +51,9 @@ void PadGrad<DEVICE_TYPE_CPU>(real* inGrad,
const int inH, const int inH,
const int inW, const int inW,
const PadConf& pad) { const PadConf& pad) {
int cstart = pad.channelStart, cend = pad.channelEnd; int cstart = pad.channel[0], cend = pad.channel[1];
int hstart = pad.heightStart, hend = pad.heightEnd; int hstart = pad.height[0], hend = pad.height[1];
int wstart = pad.widthStart, wend = pad.widthEnd; int wstart = pad.width[0], wend = pad.width[1];
int outC = inC + cstart + cend; int outC = inC + cstart + cend;
int outH = inH + hstart + hend; int outH = inH + hstart + hend;
int outW = inW + wstart + wend; int outW = inW + wstart + wend;
...@@ -71,6 +71,12 @@ void PadGrad<DEVICE_TYPE_CPU>(real* inGrad, ...@@ -71,6 +71,12 @@ void PadGrad<DEVICE_TYPE_CPU>(real* inGrad,
} }
} }
/// Build a PadConf from a FuncConfig by reading the three per-dimension
/// pad-size vectors ("channel", "height", "width"), each holding the
/// before/after padding counts.
static inline PadConf castToPadConf(const FuncConfig& conf) {
  PadConf pad;
  pad.channel = conf.get<std::vector<uint32_t>>("channel");
  pad.height = conf.get<std::vector<uint32_t>>("height");
  pad.width = conf.get<std::vector<uint32_t>>("width");
  return pad;
}
/** /**
* \brief Padding zeros to input according to the specify dimension. * \brief Padding zeros to input according to the specify dimension.
* The struct pad_ contains the padding size in each dimension. * The struct pad_ contains the padding size in each dimension.
...@@ -127,14 +133,7 @@ void PadGrad<DEVICE_TYPE_CPU>(real* inGrad, ...@@ -127,14 +133,7 @@ void PadGrad<DEVICE_TYPE_CPU>(real* inGrad,
template <DeviceType Device> template <DeviceType Device>
class PadFunc : public FunctionBase { class PadFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override { pad_ = castToPadConf(config); }
pad_.channelStart = config.get<int>("cstart");
pad_.channelEnd = config.get<int>("cend");
pad_.heightStart = config.get<int>("hstart");
pad_.heightEnd = config.get<int>("hend");
pad_.widthStart = config.get<int>("wstart");
pad_.widthEnd = config.get<int>("wend");
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1UL, inputs.size()); CHECK_EQ(1UL, inputs.size());
...@@ -175,14 +174,7 @@ private: ...@@ -175,14 +174,7 @@ private:
template <DeviceType Device> template <DeviceType Device>
class PadGradFunc : public FunctionBase { class PadGradFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override { pad_ = castToPadConf(config); }
pad_.channelStart = config.get<int>("cstart");
pad_.channelEnd = config.get<int>("cend");
pad_.heightStart = config.get<int>("hstart");
pad_.heightEnd = config.get<int>("hend");
pad_.widthStart = config.get<int>("wstart");
pad_.widthEnd = config.get<int>("wend");
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1UL, inputs.size()); CHECK_EQ(1UL, inputs.size());
......
...@@ -19,18 +19,12 @@ limitations under the License. */ ...@@ -19,18 +19,12 @@ limitations under the License. */
namespace paddle { namespace paddle {
struct PadConf {
  /// how many values to add before/after the data along channel dimension.
  std::vector<uint32_t> channel;
  /// how many values to add before/after the data along height dimension.
  std::vector<uint32_t> height;
  /// how many values to add before/after the data along width dimension.
  std::vector<uint32_t> width;
};
/** /**
......
...@@ -36,12 +36,9 @@ bool PadLayer::init(const LayerMap& layerMap, ...@@ -36,12 +36,9 @@ bool PadLayer::init(const LayerMap& layerMap,
CHECK_EQ(2, pad_conf.pad_c_size()); CHECK_EQ(2, pad_conf.pad_c_size());
CHECK_EQ(2, pad_conf.pad_h_size()); CHECK_EQ(2, pad_conf.pad_h_size());
CHECK_EQ(2, pad_conf.pad_w_size()); CHECK_EQ(2, pad_conf.pad_w_size());
padc_.push_back(pad_conf.pad_c(0)); padc_ = {pad_conf.pad_c(0), pad_conf.pad_c(1)};
padc_.push_back(pad_conf.pad_c(1)); padh_ = {pad_conf.pad_h(0), pad_conf.pad_h(1)};
padh_.push_back(pad_conf.pad_h(0)); padw_ = {pad_conf.pad_w(0), pad_conf.pad_w(1)};
padh_.push_back(pad_conf.pad_h(1));
padw_.push_back(pad_conf.pad_w(0));
padw_.push_back(pad_conf.pad_w(1));
outDims_ = TensorShape(4); outDims_ = TensorShape(4);
setOutDims(0); setOutDims(0);
...@@ -49,21 +46,15 @@ bool PadLayer::init(const LayerMap& layerMap, ...@@ -49,21 +46,15 @@ bool PadLayer::init(const LayerMap& layerMap,
createFunction(forward_, createFunction(forward_,
"Pad", "Pad",
FuncConfig() FuncConfig()
.set("cstart", padc_[0]) .set("channel", padc_)
.set("cend", padc_[1]) .set("height", padh_)
.set("hstart", padh_[0]) .set("width", padw_));
.set("hend", padh_[1])
.set("wstart", padw_[0])
.set("wend", padw_[1]));
createFunction(backward_, createFunction(backward_,
"PadGrad", "PadGrad",
FuncConfig() FuncConfig()
.set("cstart", padc_[0]) .set("channel", padc_)
.set("cend", padc_[1]) .set("height", padh_)
.set("hstart", padh_[0]) .set("width", padw_));
.set("hend", padh_[1])
.set("wstart", padw_[0])
.set("wend", padw_[1]));
return true; return true;
} }
......
...@@ -38,9 +38,9 @@ protected: ...@@ -38,9 +38,9 @@ protected:
void setOutDims(const size_t batchSize); void setOutDims(const size_t batchSize);
void setTensorDim(const size_t batchSize); void setTensorDim(const size_t batchSize);
std::vector<int> padc_; std::vector<uint32_t> padc_;
std::vector<int> padh_; std::vector<uint32_t> padh_;
std::vector<int> padw_; std::vector<uint32_t> padw_;
TensorShape inDims_; TensorShape inDims_;
TensorShape outDims_; TensorShape outDims_;
}; };
......
...@@ -20,6 +20,7 @@ namespace paddle { ...@@ -20,6 +20,7 @@ namespace paddle {
// using std::any for C++ 17 // using std::any for C++ 17
using std::any; using std::any;
using std::any_cast; using std::any_cast;
using std::bad_any_cast;
} // namespace paddle } // namespace paddle
#else #else
...@@ -29,5 +30,6 @@ namespace paddle { ...@@ -29,5 +30,6 @@ namespace paddle {
// use linb::any for C++ 11 // use linb::any for C++ 11
using linb::any; using linb::any;
using linb::any_cast; using linb::any_cast;
using linb::bad_any_cast;
} // namespace paddle } // namespace paddle
#endif #endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this message first!
To comment, please register.