提交 ce1d98e0 编写于 作者: H hedaoyuan

Add a Tensor to use as a Function argument

上级 a1d2abc1
......@@ -40,7 +40,17 @@ struct MatrixT<DEVICE_TYPE_GPU> {
using type = GpuMatrix;
};
typedef std::vector<Matrix> Arguments;
typedef std::vector<size_t> Dims;
// Pairs a raw pointer to `real` data with a list of dimension sizes.
// The pointer is stored exactly as given: this class declares no destructor
// and neither allocates nor releases the buffer it points to.
class Tensor {
public:
  // Pointer to the first element, kept as passed to the constructor.
  real* buf_;
  // Dimension sizes, copied from the constructor argument.
  Dims dims_;

  // data: pointer stored in buf_.
  // dim:  dimension sizes copied into dims_.
  Tensor(real* data, const Dims& dim)
      : buf_(data),
        dims_(dim) {}
};
typedef std::vector<Tensor> Arguments;
class FuncConfig {
public:
......
......@@ -144,26 +144,23 @@ public:
CHECK_EQ(2, outputs.size());
CHECK_EQ(0, inouts.size());
auto input = dynamic_cast<const typename MatrixT<Device>::type&>(inputs[0]);
auto output =
dynamic_cast<const typename MatrixT<Device>::type&>(outputs[0]);
auto denom =
dynamic_cast<const typename MatrixT<Device>::type&>(outputs[1]);
CHECK(input.isContiguous());
CHECK(output.isContiguous());
CHECK(denom.isContiguous());
CHECK_EQ(output.getHeight(), input.getHeight());
CHECK_EQ(output.getWidth(), input.getWidth());
CHECK_EQ(output.getHeight(), denom.getHeight());
CHECK_EQ(output.getWidth(), denom.getWidth());
// CrossMapNormal<Device> cross;
// need:
// size_t channels,
// size_t imgSizeH,
// size_t imgSizeW,
// cross(output, denom, input, );
CHECK_EQ(inputs[0].dims_.size(), 4);
for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]);
}
size_t samples = inputs[0].dims_[0];
size_t channels = inputs[0].dims_[1];
size_t height = inputs[0].dims_[2];
size_t width = inputs[0].dims_[3];
size_t imageSize = channels * height * width;
CpuMatrix input(inputs[0].buf_, samples, imageSize);
CpuMatrix output(outputs[0].buf_, samples, imageSize);
CpuMatrix denom(outputs[1].buf_, samples, imageSize);
CrossMapNormal<Device> cross;
cross(output, denom, input, channels, height, width, size_, scale_, pow_);
}
private:
......
......@@ -1288,12 +1288,17 @@ void testCrossMapNormalFwd(
FunctionBase* cpu =
FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU));
cpu->init(config);
// cpu->calc();
Dims dims{
(size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW};
cpu->calc({Tensor(inputs.getData(), dims)},
{Tensor(outputs.getData(), dims), Tensor(denoms.getData(), dims)},
{});
#if 0
CrossMapNormal<DEVICE_TYPE_CPU> cpuCross;
cpuCross(
outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow);
#endif
CrossMapNormal<DEVICE_TYPE_GPU> gpuCross;
gpuCross(outputsGpu,
denomsGpu,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册