提交 ce1d98e0 编写于 作者: H hedaoyuan

Add a Tensor to use as a Function argument

上级 a1d2abc1
...@@ -40,7 +40,17 @@ struct MatrixT<DEVICE_TYPE_GPU> { ...@@ -40,7 +40,17 @@ struct MatrixT<DEVICE_TYPE_GPU> {
using type = GpuMatrix; using type = GpuMatrix;
}; };
typedef std::vector<Matrix> Arguments; typedef std::vector<size_t> Dims;
// A lightweight, non-owning view of a flat buffer of `real` values together
// with its dimension sizes; used as the element type of a Function's
// `Arguments` list (see the `Arguments` typedef below this class).
// NOTE(review): Tensor does not allocate or free `buf_` — the caller owns
// the underlying storage (e.g. a Matrix's data pointer) and must keep it
// alive for the Tensor's lifetime.
class Tensor {
public:
  // Wraps an existing buffer; `dim` holds the size of each dimension.
  // Callers in this change pass {numSamples, channels, imgSizeH, imgSizeW}
  // — presumably NCHW layout, TODO confirm for other call sites.
  Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {}
  real* buf_;  // raw pointer to element storage (not owned)
  Dims dims_;  // per-dimension sizes; dims_.size() is the tensor rank
};
typedef std::vector<Tensor> Arguments;
class FuncConfig { class FuncConfig {
public: public:
......
...@@ -144,26 +144,23 @@ public: ...@@ -144,26 +144,23 @@ public:
CHECK_EQ(2, outputs.size()); CHECK_EQ(2, outputs.size());
CHECK_EQ(0, inouts.size()); CHECK_EQ(0, inouts.size());
auto input = dynamic_cast<const typename MatrixT<Device>::type&>(inputs[0]); CHECK_EQ(inputs[0].dims_.size(), 4);
auto output = for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
dynamic_cast<const typename MatrixT<Device>::type&>(outputs[0]); CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
auto denom = CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]);
dynamic_cast<const typename MatrixT<Device>::type&>(outputs[1]); }
CHECK(input.isContiguous()); size_t samples = inputs[0].dims_[0];
CHECK(output.isContiguous()); size_t channels = inputs[0].dims_[1];
CHECK(denom.isContiguous()); size_t height = inputs[0].dims_[2];
CHECK_EQ(output.getHeight(), input.getHeight()); size_t width = inputs[0].dims_[3];
CHECK_EQ(output.getWidth(), input.getWidth()); size_t imageSize = channels * height * width;
CHECK_EQ(output.getHeight(), denom.getHeight()); CpuMatrix input(inputs[0].buf_, samples, imageSize);
CHECK_EQ(output.getWidth(), denom.getWidth()); CpuMatrix output(outputs[0].buf_, samples, imageSize);
CpuMatrix denom(outputs[1].buf_, samples, imageSize);
// CrossMapNormal<Device> cross;
// need: CrossMapNormal<Device> cross;
// size_t channels, cross(output, denom, input, channels, height, width, size_, scale_, pow_);
// size_t imgSizeH,
// size_t imgSizeW,
// cross(output, denom, input, );
} }
private: private:
......
...@@ -1288,12 +1288,17 @@ void testCrossMapNormalFwd( ...@@ -1288,12 +1288,17 @@ void testCrossMapNormalFwd(
FunctionBase* cpu = FunctionBase* cpu =
FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU));
cpu->init(config); cpu->init(config);
// cpu->calc();
Dims dims{
(size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW};
cpu->calc({Tensor(inputs.getData(), dims)},
{Tensor(outputs.getData(), dims), Tensor(denoms.getData(), dims)},
{});
#if 0
CrossMapNormal<DEVICE_TYPE_CPU> cpuCross; CrossMapNormal<DEVICE_TYPE_CPU> cpuCross;
cpuCross( cpuCross(
outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow);
#endif
CrossMapNormal<DEVICE_TYPE_GPU> gpuCross; CrossMapNormal<DEVICE_TYPE_GPU> gpuCross;
gpuCross(outputsGpu, gpuCross(outputsGpu,
denomsGpu, denomsGpu,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册