Commit 148bd4d0, authored by hedaoyuan

add Layer::createFunction

Parent cee93468
@@ -102,9 +102,9 @@ protected:
   std::vector<bool> markInBackward_;

   /// Layer forward function
-  FunctionBase* forward_;
+  std::vector<std::shared_ptr<FunctionBase>> forward_;
   /// Layer backward function
-  FunctionBase* backward_;
+  std::vector<std::shared_ptr<FunctionBase>> backward_;

 public:
   /**
@@ -132,6 +132,26 @@ public:
   virtual void markAllInputGrad();

 protected:
+  /**
+   * Create a layer function. The function is called in forward() or backward().
+   * \param function  Layer::forward_ or Layer::backward_
+   * \param name      the registered function name
+   * \param config    the initialization configuration for the function
+   */
+  void createFunction(std::vector<std::shared_ptr<FunctionBase>>& function,
+                      const std::string& name,
+                      const FuncConfig& config) {
+    if (useGpu_) {
+      function.emplace_back(
+          FunctionBase::funcRegistrar_.createByType(name + "-GPU"));
+    } else {
+      function.emplace_back(
+          FunctionBase::funcRegistrar_.createByType(name + "-CPU"));
+    }
+    auto& func = function.back();
+    func->init(config);
+  }
+
   /**
    * Notify the specified layer that its output grad is ready.
    * Called in the backward function.
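Note: the helper above assumes the paddle/function API: a registrar that
creates a function from a "Name-GPU"/"Name-CPU" type string, and a chainable
FuncConfig. A simplified, self-contained sketch of that interface follows;
the real definitions live in paddle/function, and the member layout here is
an assumption for illustration only.

#include <functional>
#include <map>
#include <string>

// Simplified stand-ins for the paddle/function classes used by createFunction.
struct FuncConfig {
  std::map<std::string, double> args;
  // Chainable setter, matching FuncConfig().set("size", ...).set("scale", ...).
  FuncConfig& set(const std::string& key, double value) {
    args[key] = value;
    return *this;
  }
};

struct FunctionBase {
  virtual ~FunctionBase() = default;
  virtual void init(const FuncConfig& config) {}
};

// Each function is registered under two type strings, e.g.
// "CrossMapNormal-CPU" and "CrossMapNormal-GPU"; createFunction appends
// the suffix chosen by useGpu_.
struct FuncRegistrar {
  std::map<std::string, std::function<FunctionBase*()>> creators;
  FunctionBase* createByType(const std::string& type) {
    return creators.at(type)();  // throws std::out_of_range if unregistered
  }
};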
@@ -45,21 +45,13 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
   /* the size of inputs for norm-layer is 1 */
   CHECK_EQ(config_.inputs_size(), 1);

-  if (useGpu_) {
-    forward_ = FunctionBase::funcRegistrar_.createByType(
-        FUNC_NAME(CrossMapNormal, GPU));
-    backward_ = FunctionBase::funcRegistrar_.createByType(
-        FUNC_NAME(CrossMapNormalGrad, GPU));
-  } else {
-    forward_ = FunctionBase::funcRegistrar_.createByType(
-        FUNC_NAME(CrossMapNormal, CPU));
-    backward_ = FunctionBase::funcRegistrar_.createByType(
-        FUNC_NAME(CrossMapNormalGrad, CPU));
-  }
-  forward_->init(
+  createFunction(
+      forward_,
+      "CrossMapNormal",
       FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
-  backward_->init(
+  createFunction(
+      backward_,
+      "CrossMapNormalGrad",
       FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));

   return true;
@@ -80,7 +72,7 @@ void CMRProjectionNormLayer::forward(PassType passType) {
   Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_);

   dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_};
-  forward_->calc(
+  forward_[0]->calc(
       {Tensor(input->getData(), dims_)},
       {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)},
       {});
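Note: calc() takes three argument lists: inputs, outputs, and in/out tensors.
Forward consumes the input image and fills both the layer output and the
denoms_ buffer that backward reuses. A minimal sketch of the Tensor view
being passed, assuming a non-owning (pointer, dims) pair; the actual type is
defined in paddle/function and may differ:

#include <cstddef>
#include <vector>

typedef float real;  // paddle builds with float or double; float assumed here

// Non-owning view over a Matrix buffer plus the shared dims_ shape.
struct Tensor {
  real* data;                  // not owned; points into an existing Matrix
  std::vector<size_t> dims;    // {batchSize, channels, imgSizeH, imgSizeW}
  Tensor(real* d, const std::vector<size_t>& s) : data(d), dims(s) {}
};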
......@@ -98,11 +90,11 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
MatrixPtr localOutV = getOutputValue();
MatrixPtr preOutV = inputLayers_[0]->getOutputValue();
backward_->calc({Tensor(preOutV->getData(), dims_),
Tensor(localOutV->getData(), dims_),
Tensor(localGrad->getData(), dims_),
Tensor(denoms_->getData(), dims_)},
{Tensor(preOutGrad->getData(), dims_)},
{});
backward_[0]->calc({Tensor(preOutV->getData(), dims_),
Tensor(localOutV->getData(), dims_),
Tensor(localGrad->getData(), dims_),
Tensor(denoms_->getData(), dims_)},
{Tensor(preOutGrad->getData(), dims_)},
{});
}
} // namespace paddle
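Usage note: with the vector-based members, a layer can register several
functions per direction. A hypothetical layer adopting the helper might look
like the sketch below; MyLayer, "MyOp", and "MyOpGrad" are invented names,
and any pair registered with "-CPU"/"-GPU" variants works like CrossMapNormal
above.

// Hypothetical adopter of Layer::createFunction.
bool MyLayer::init(const LayerMap& layerMap, const ParameterMap& parameterMap) {
  if (!Layer::init(layerMap, parameterMap)) return false;
  createFunction(forward_, "MyOp", FuncConfig().set("size", size_));
  createFunction(backward_, "MyOpGrad", FuncConfig().set("size", size_));
  return true;
}
// forward()/backward() then dispatch through forward_[0] / backward_[0];
// a multi-stage layer could push additional functions onto the same vectors.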