From ce1d98e083017afadac9fcd9f94f5c59aceaf6c0 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 10:31:45 +0800 Subject: [PATCH] Add a Tensor to use as a Function argument --- paddle/math/Function.h | 12 +++++++- paddle/math/cross_map_normal_op.cpp | 37 +++++++++++------------- paddle/math/tests/test_matrixCompare.cpp | 9 ++++-- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/paddle/math/Function.h b/paddle/math/Function.h index b41ba2a13d..539759782b 100644 --- a/paddle/math/Function.h +++ b/paddle/math/Function.h @@ -40,7 +40,17 @@ struct MatrixT { using type = GpuMatrix; }; -typedef std::vector Arguments; +typedef std::vector Dims; + +class Tensor { +public: + Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {} + + real* buf_; + Dims dims_; +}; + +typedef std::vector Arguments; class FuncConfig { public: diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp index 0b72732063..d55bd78c62 100644 --- a/paddle/math/cross_map_normal_op.cpp +++ b/paddle/math/cross_map_normal_op.cpp @@ -144,26 +144,23 @@ public: CHECK_EQ(2, outputs.size()); CHECK_EQ(0, inouts.size()); - auto input = dynamic_cast::type&>(inputs[0]); - auto output = - dynamic_cast::type&>(outputs[0]); - auto denom = - dynamic_cast::type&>(outputs[1]); - - CHECK(input.isContiguous()); - CHECK(output.isContiguous()); - CHECK(denom.isContiguous()); - CHECK_EQ(output.getHeight(), input.getHeight()); - CHECK_EQ(output.getWidth(), input.getWidth()); - CHECK_EQ(output.getHeight(), denom.getHeight()); - CHECK_EQ(output.getWidth(), denom.getWidth()); - - // CrossMapNormal cross; - // need: - // size_t channels, - // size_t imgSizeH, - // size_t imgSizeW, - // cross(output, denom, input, ); + CHECK_EQ(inputs[0].dims_.size(), 4); + for (size_t i = 0; i < inputs[0].dims_.size(); i++) { + CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]); + CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]); + } + + size_t samples = inputs[0].dims_[0]; + size_t channels = inputs[0].dims_[1]; + size_t height = inputs[0].dims_[2]; + size_t width = inputs[0].dims_[3]; + size_t imageSize = channels * height * width; + CpuMatrix input(inputs[0].buf_, samples, imageSize); + CpuMatrix output(outputs[0].buf_, samples, imageSize); + CpuMatrix denom(outputs[1].buf_, samples, imageSize); + + CrossMapNormal cross; + cross(output, denom, input, channels, height, width, size_, scale_, pow_); } private: diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 0b75785528..cd34ea18a7 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1288,12 +1288,17 @@ void testCrossMapNormalFwd( FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); cpu->init(config); - // cpu->calc(); + Dims dims{ + (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW}; + cpu->calc({Tensor(inputs.getData(), dims)}, + {Tensor(outputs.getData(), dims), Tensor(denoms.getData(), dims)}, + {}); +#if 0 CrossMapNormal cpuCross; cpuCross( outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); - +#endif CrossMapNormal gpuCross; gpuCross(outputsGpu, denomsGpu, -- GitLab