diff --git a/paddle/math/Function.h b/paddle/math/Function.h index b41ba2a13d3774932fd9efe36be77669eba4a0f3..539759782be3a186d3bbd6f90296c4110f6ddb0a 100644 --- a/paddle/math/Function.h +++ b/paddle/math/Function.h @@ -40,7 +40,17 @@ struct MatrixT { using type = GpuMatrix; }; -typedef std::vector Arguments; +typedef std::vector Dims; + +class Tensor { +public: + Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {} + + real* buf_; + Dims dims_; +}; + +typedef std::vector Arguments; class FuncConfig { public: diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp index 0b7273206381d267f346e208122479190a762423..d55bd78c628f734fe4185d09e7b4d8c1aaa820d6 100644 --- a/paddle/math/cross_map_normal_op.cpp +++ b/paddle/math/cross_map_normal_op.cpp @@ -144,26 +144,23 @@ public: CHECK_EQ(2, outputs.size()); CHECK_EQ(0, inouts.size()); - auto input = dynamic_cast::type&>(inputs[0]); - auto output = - dynamic_cast::type&>(outputs[0]); - auto denom = - dynamic_cast::type&>(outputs[1]); - - CHECK(input.isContiguous()); - CHECK(output.isContiguous()); - CHECK(denom.isContiguous()); - CHECK_EQ(output.getHeight(), input.getHeight()); - CHECK_EQ(output.getWidth(), input.getWidth()); - CHECK_EQ(output.getHeight(), denom.getHeight()); - CHECK_EQ(output.getWidth(), denom.getWidth()); - - // CrossMapNormal cross; - // need: - // size_t channels, - // size_t imgSizeH, - // size_t imgSizeW, - // cross(output, denom, input, ); + CHECK_EQ(inputs[0].dims_.size(), 4); + for (size_t i = 0; i < inputs[0].dims_.size(); i++) { + CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]); + CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]); + } + + size_t samples = inputs[0].dims_[0]; + size_t channels = inputs[0].dims_[1]; + size_t height = inputs[0].dims_[2]; + size_t width = inputs[0].dims_[3]; + size_t imageSize = channels * height * width; + CpuMatrix input(inputs[0].buf_, samples, imageSize); + CpuMatrix output(outputs[0].buf_, samples, imageSize); + CpuMatrix denom(outputs[1].buf_, samples, imageSize); + + CrossMapNormal cross; + cross(output, denom, input, channels, height, width, size_, scale_, pow_); } private: diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 0b75785528f5d8da8503a54ca8280aba8203f377..cd34ea18a70ea29b61b261fc36cb9bb11364ea80 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1288,12 +1288,17 @@ void testCrossMapNormalFwd( FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); cpu->init(config); - // cpu->calc(); + Dims dims{ + (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW}; + cpu->calc({Tensor(inputs.getData(), dims)}, + {Tensor(outputs.getData(), dims), Tensor(denoms.getData(), dims)}, + {}); +#if 0 CrossMapNormal cpuCross; cpuCross( outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); - +#endif CrossMapNormal gpuCross; gpuCross(outputsGpu, denomsGpu,