提交 68156c88 编写于 作者: H hedaoyuan

Modify the argument type of Function

上级 c5c80516
......@@ -125,27 +125,25 @@ public:
pow_ = config.get<real>("pow");
}
void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) override {
void calc(const BufferArgs& inputs,
const BufferArgs& outputs,
const BufferArgs& inouts) override {
CHECK_EQ(1, inputs.size());
CHECK_EQ(2, outputs.size());
CHECK_EQ(0, inouts.size());
CHECK_EQ(inputs[0].dims_.size(), 4);
for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]);
}
CHECK_EQ(inputs[0].shape().ndims(), 4);
CHECK(inputs[0].shape() == outputs[0].shape());
CHECK(inputs[0].shape() == outputs[1].shape());
size_t samples = inputs[0].dims_[0];
size_t channels = inputs[0].dims_[1];
size_t height = inputs[0].dims_[2];
size_t width = inputs[0].dims_[3];
size_t samples = inputs[0].shape()[0];
size_t channels = inputs[0].shape()[1];
size_t height = inputs[0].shape()[2];
size_t width = inputs[0].shape()[3];
CrossMapNormal<Device>(outputs[0].getData(),
outputs[1].getData(),
inputs[0].getData(),
CrossMapNormal<Device>(outputs[0].data<real>(),
outputs[1].data<real>(),
inputs[0].data<real>(),
samples,
channels,
height,
......@@ -177,31 +175,29 @@ public:
pow_ = config.get<real>("pow");
}
void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) override {
void calc(const BufferArgs& inputs,
const BufferArgs& outputs,
const BufferArgs& inouts) override {
CHECK_EQ(4, inputs.size());
CHECK_EQ(1, outputs.size());
CHECK_EQ(0, inouts.size());
CHECK_EQ(inputs[0].dims_.size(), 4);
for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
CHECK_EQ(inputs[0].dims_[i], inputs[1].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], inputs[2].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], inputs[3].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
}
size_t samples = inputs[0].dims_[0];
size_t channels = inputs[0].dims_[1];
size_t height = inputs[0].dims_[2];
size_t width = inputs[0].dims_[3];
CrossMapNormalGrad<Device>(outputs[0].getData(),
inputs[0].getData(),
inputs[1].getData(),
inputs[2].getData(),
inputs[3].getData(),
CHECK_EQ(inputs[0].shape().ndims(), 4);
CHECK(inputs[0].shape() == inputs[1].shape());
CHECK(inputs[0].shape() == inputs[2].shape());
CHECK(inputs[0].shape() == inputs[3].shape());
CHECK(inputs[0].shape() == outputs[0].shape());
size_t samples = inputs[0].shape()[0];
size_t channels = inputs[0].shape()[1];
size_t height = inputs[0].shape()[2];
size_t width = inputs[0].shape()[3];
CrossMapNormalGrad<Device>(outputs[0].data<real>(),
inputs[0].data<real>(),
inputs[1].data<real>(),
inputs[2].data<real>(),
inputs[3].data<real>(),
samples,
channels,
height,
......
......@@ -16,57 +16,12 @@ limitations under the License. */
#include <map>
#include <vector>
#include "BufferArg.h"
#include "paddle/math/Matrix.h"
#include "paddle/utils/ClassRegistrar.h"
namespace paddle {
enum DeviceType {
DEVICE_TYPE_UNSPECIFIED = 0,
DEVICE_TYPE_CPU = 1,
DEVICE_TYPE_GPU = 2,
};
template <DeviceType Device>
struct MatrixT;
template <>
struct MatrixT<DEVICE_TYPE_CPU> {
using type = CpuMatrix;
};
template <>
struct MatrixT<DEVICE_TYPE_GPU> {
using type = GpuMatrix;
};
template <DeviceType Device>
struct SequenceT;
template <>
struct SequenceT<DEVICE_TYPE_CPU> {
using type = CpuIVector;
};
template <>
struct SequenceT<DEVICE_TYPE_GPU> {
using type = GpuIVector;
};
typedef std::vector<size_t> Dims;
class Tensor {
public:
Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {}
real* getData() const { return buf_; }
real* buf_;
Dims dims_;
};
typedef std::vector<Tensor> Arguments;
class FuncConfig {
public:
union value {
......@@ -92,9 +47,9 @@ public:
virtual void init(const FuncConfig& config) {}
virtual void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) {}
virtual void calc(const BufferArgs& inputs,
const BufferArgs& outputs,
const BufferArgs& inouts) {}
static ClassRegistrar<FunctionBase> funcRegistrar_;
};
......
......@@ -71,11 +71,16 @@ void CMRProjectionNormLayer::forward(PassType passType) {
Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_);
dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_};
forward_[0]->calc(
{Tensor(input->getData(), dims_)},
{Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)},
{});
shape_ = TensorShape({batchSize, channels_, imgSizeH_, imgSizeW_});
BufferArgs inputs;
BufferArgs outputs;
BufferArgs inouts;
inputs.addArg(*input, shape_);
outputs.addArg(*outV, shape_);
outputs.addArg(*denoms_, shape_);
forward_[0]->calc(inputs, outputs, inouts);
}
void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
......@@ -90,11 +95,14 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
MatrixPtr localOutV = getOutputValue();
MatrixPtr preOutV = inputLayers_[0]->getOutputValue();
backward_[0]->calc({Tensor(preOutV->getData(), dims_),
Tensor(localOutV->getData(), dims_),
Tensor(localGrad->getData(), dims_),
Tensor(denoms_->getData(), dims_)},
{Tensor(preOutGrad->getData(), dims_)},
{});
BufferArgs inputs;
BufferArgs outputs;
BufferArgs inouts;
inputs.addArg(*preOutV, shape_);
inputs.addArg(*localOutV, shape_);
inputs.addArg(*localGrad, shape_);
inputs.addArg(*denoms_, shape_);
outputs.addArg(*preOutGrad, shape_);
backward_[0]->calc(inputs, outputs, inouts);
}
} // namespace paddle
......@@ -41,6 +41,6 @@ public:
void backward(const UpdateCallback& callback = nullptr);
protected:
Dims dims_;
TensorShape shape_;
};
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册