提交 148bd4d0 编写于 作者: H hedaoyuan

add Layer::createFunction

上级 cee93468
...@@ -102,9 +102,9 @@ protected: ...@@ -102,9 +102,9 @@ protected:
std::vector<bool> markInBackward_; std::vector<bool> markInBackward_;
/// Layer forward function /// Layer forward function
FunctionBase* forward_; std::vector<std::shared_ptr<FunctionBase>> forward_;
/// Layer backward function /// Layer backward function
FunctionBase* backward_; std::vector<std::shared_ptr<FunctionBase>> backward_;
public: public:
/** /**
...@@ -132,6 +132,26 @@ public: ...@@ -132,6 +132,26 @@ public:
virtual void markAllInputGrad(); virtual void markAllInputGrad();
protected: protected:
/**
* Create layer function. Function is called in forward or backward.
* \param function, Layer::forward_ or Layer::backward_
* \param name, function name
* \param config, initialization configuration for the function
*/
void createFunction(std::vector<std::shared_ptr<FunctionBase>>& function,
const std::string& name,
const FuncConfig& config) {
if (useGpu_) {
function.emplace_back(
FunctionBase::funcRegistrar_.createByType(name + "-GPU"));
} else {
function.emplace_back(
FunctionBase::funcRegistrar_.createByType(name + "-CPU"));
}
auto& func = function.back();
func->init(config);
}
/** /**
* Notify specified layer the output grad ready. * Notify specified layer the output grad ready.
* Called in the backward function. * Called in the backward function.
......
...@@ -45,21 +45,13 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap, ...@@ -45,21 +45,13 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
/* the size of inputs for norm-layer is 1 */ /* the size of inputs for norm-layer is 1 */
CHECK_EQ(config_.inputs_size(), 1); CHECK_EQ(config_.inputs_size(), 1);
if (useGpu_) { createFunction(
forward_ = FunctionBase::funcRegistrar_.createByType( forward_,
FUNC_NAME(CrossMapNormal, GPU)); "CrossMapNormal",
backward_ = FunctionBase::funcRegistrar_.createByType(
FUNC_NAME(CrossMapNormalGrad, GPU));
} else {
forward_ = FunctionBase::funcRegistrar_.createByType(
FUNC_NAME(CrossMapNormal, CPU));
backward_ = FunctionBase::funcRegistrar_.createByType(
FUNC_NAME(CrossMapNormalGrad, CPU));
}
forward_->init(
FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
createFunction(
backward_->init( backward_,
"CrossMapNormalGrad",
FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
return true; return true;
...@@ -80,7 +72,7 @@ void CMRProjectionNormLayer::forward(PassType passType) { ...@@ -80,7 +72,7 @@ void CMRProjectionNormLayer::forward(PassType passType) {
Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_); Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_);
dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_}; dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_};
forward_->calc( forward_[0]->calc(
{Tensor(input->getData(), dims_)}, {Tensor(input->getData(), dims_)},
{Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)}, {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)},
{}); {});
...@@ -98,11 +90,11 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { ...@@ -98,11 +90,11 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
MatrixPtr localOutV = getOutputValue(); MatrixPtr localOutV = getOutputValue();
MatrixPtr preOutV = inputLayers_[0]->getOutputValue(); MatrixPtr preOutV = inputLayers_[0]->getOutputValue();
backward_->calc({Tensor(preOutV->getData(), dims_), backward_[0]->calc({Tensor(preOutV->getData(), dims_),
Tensor(localOutV->getData(), dims_), Tensor(localOutV->getData(), dims_),
Tensor(localGrad->getData(), dims_), Tensor(localGrad->getData(), dims_),
Tensor(denoms_->getData(), dims_)}, Tensor(denoms_->getData(), dims_)},
{Tensor(preOutGrad->getData(), dims_)}, {Tensor(preOutGrad->getData(), dims_)},
{}); {});
} }
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册