From 68156c88c50aff2c614ecc69b56bd5f814dc30be Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 5 Jan 2017 19:45:12 +0800 Subject: [PATCH] Modify the argument type of Function --- paddle/function/CrossMapNormalOp.cpp | 68 +++++++++---------- paddle/function/Function.h | 53 ++------------- paddle/gserver/layers/NormProjectionLayer.cpp | 30 +++++--- paddle/gserver/layers/NormProjectionLayer.h | 2 +- 4 files changed, 56 insertions(+), 97 deletions(-) diff --git a/paddle/function/CrossMapNormalOp.cpp b/paddle/function/CrossMapNormalOp.cpp index f13eb78d27d..ec27db9c212 100644 --- a/paddle/function/CrossMapNormalOp.cpp +++ b/paddle/function/CrossMapNormalOp.cpp @@ -125,27 +125,25 @@ public: pow_ = config.get("pow"); } - void calc(const Arguments& inputs, - const Arguments& outputs, - const Arguments& inouts) override { + void calc(const BufferArgs& inputs, + const BufferArgs& outputs, + const BufferArgs& inouts) override { CHECK_EQ(1, inputs.size()); CHECK_EQ(2, outputs.size()); CHECK_EQ(0, inouts.size()); - CHECK_EQ(inputs[0].dims_.size(), 4); - for (size_t i = 0; i < inputs[0].dims_.size(); i++) { - CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]); - CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]); - } + CHECK_EQ(inputs[0].shape().ndims(), 4); + CHECK(inputs[0].shape() == outputs[0].shape()); + CHECK(inputs[0].shape() == outputs[1].shape()); - size_t samples = inputs[0].dims_[0]; - size_t channels = inputs[0].dims_[1]; - size_t height = inputs[0].dims_[2]; - size_t width = inputs[0].dims_[3]; + size_t samples = inputs[0].shape()[0]; + size_t channels = inputs[0].shape()[1]; + size_t height = inputs[0].shape()[2]; + size_t width = inputs[0].shape()[3]; - CrossMapNormal(outputs[0].getData(), - outputs[1].getData(), - inputs[0].getData(), + CrossMapNormal(outputs[0].data(), + outputs[1].data(), + inputs[0].data(), samples, channels, height, @@ -177,31 +175,29 @@ public: pow_ = config.get("pow"); } - void calc(const Arguments& inputs, - const Arguments& outputs, - const Arguments& inouts) override { + void calc(const BufferArgs& inputs, + const BufferArgs& outputs, + const BufferArgs& inouts) override { CHECK_EQ(4, inputs.size()); CHECK_EQ(1, outputs.size()); CHECK_EQ(0, inouts.size()); - CHECK_EQ(inputs[0].dims_.size(), 4); - for (size_t i = 0; i < inputs[0].dims_.size(); i++) { - CHECK_EQ(inputs[0].dims_[i], inputs[1].dims_[i]); - CHECK_EQ(inputs[0].dims_[i], inputs[2].dims_[i]); - CHECK_EQ(inputs[0].dims_[i], inputs[3].dims_[i]); - CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]); - } - - size_t samples = inputs[0].dims_[0]; - size_t channels = inputs[0].dims_[1]; - size_t height = inputs[0].dims_[2]; - size_t width = inputs[0].dims_[3]; - - CrossMapNormalGrad(outputs[0].getData(), - inputs[0].getData(), - inputs[1].getData(), - inputs[2].getData(), - inputs[3].getData(), + CHECK_EQ(inputs[0].shape().ndims(), 4); + CHECK(inputs[0].shape() == inputs[1].shape()); + CHECK(inputs[0].shape() == inputs[2].shape()); + CHECK(inputs[0].shape() == inputs[3].shape()); + CHECK(inputs[0].shape() == outputs[0].shape()); + + size_t samples = inputs[0].shape()[0]; + size_t channels = inputs[0].shape()[1]; + size_t height = inputs[0].shape()[2]; + size_t width = inputs[0].shape()[3]; + + CrossMapNormalGrad(outputs[0].data(), + inputs[0].data(), + inputs[1].data(), + inputs[2].data(), + inputs[3].data(), samples, channels, height, diff --git a/paddle/function/Function.h b/paddle/function/Function.h index 9e8cbb8e48c..024575b4f7b 100644 --- a/paddle/function/Function.h +++ b/paddle/function/Function.h @@ -16,57 +16,12 @@ limitations under the License. */ #include #include +#include "BufferArg.h" #include "paddle/math/Matrix.h" #include "paddle/utils/ClassRegistrar.h" namespace paddle { -enum DeviceType { - DEVICE_TYPE_UNSPECIFIED = 0, - DEVICE_TYPE_CPU = 1, - DEVICE_TYPE_GPU = 2, -}; - -template -struct MatrixT; - -template <> -struct MatrixT { - using type = CpuMatrix; -}; - -template <> -struct MatrixT { - using type = GpuMatrix; -}; - -template -struct SequenceT; - -template <> -struct SequenceT { - using type = CpuIVector; -}; - -template <> -struct SequenceT { - using type = GpuIVector; -}; - -typedef std::vector Dims; - -class Tensor { -public: - Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {} - - real* getData() const { return buf_; } - - real* buf_; - Dims dims_; -}; - -typedef std::vector Arguments; - class FuncConfig { public: union value { @@ -92,9 +47,9 @@ public: virtual void init(const FuncConfig& config) {} - virtual void calc(const Arguments& inputs, - const Arguments& outputs, - const Arguments& inouts) {} + virtual void calc(const BufferArgs& inputs, + const BufferArgs& outputs, + const BufferArgs& inouts) {} static ClassRegistrar funcRegistrar_; }; diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index 262d757c67e..573de152fd0 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -71,11 +71,16 @@ void CMRProjectionNormLayer::forward(PassType passType) { Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_); - dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_}; - forward_[0]->calc( - {Tensor(input->getData(), dims_)}, - {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)}, - {}); + shape_ = TensorShape({batchSize, channels_, imgSizeH_, imgSizeW_}); + + BufferArgs inputs; + BufferArgs outputs; + BufferArgs inouts; + inputs.addArg(*input, shape_); + outputs.addArg(*outV, shape_); + outputs.addArg(*denoms_, shape_); + + forward_[0]->calc(inputs, outputs, inouts); } void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { @@ -90,11 +95,14 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { MatrixPtr localOutV = getOutputValue(); MatrixPtr preOutV = inputLayers_[0]->getOutputValue(); - backward_[0]->calc({Tensor(preOutV->getData(), dims_), - Tensor(localOutV->getData(), dims_), - Tensor(localGrad->getData(), dims_), - Tensor(denoms_->getData(), dims_)}, - {Tensor(preOutGrad->getData(), dims_)}, - {}); + BufferArgs inputs; + BufferArgs outputs; + BufferArgs inouts; + inputs.addArg(*preOutV, shape_); + inputs.addArg(*localOutV, shape_); + inputs.addArg(*localGrad, shape_); + inputs.addArg(*denoms_, shape_); + outputs.addArg(*preOutGrad, shape_); + backward_[0]->calc(inputs, outputs, inouts); } } // namespace paddle diff --git a/paddle/gserver/layers/NormProjectionLayer.h b/paddle/gserver/layers/NormProjectionLayer.h index 6b2c5dde0d7..2c0d8a3a718 100644 --- a/paddle/gserver/layers/NormProjectionLayer.h +++ b/paddle/gserver/layers/NormProjectionLayer.h @@ -41,6 +41,6 @@ public: void backward(const UpdateCallback& callback = nullptr); protected: - Dims dims_; + TensorShape shape_; }; } // namespace paddle -- GitLab