Commit 68156c88 authored by hedaoyuan

Modify the argument type of Function

Parent c5c80516
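This commit replaces the `Tensor`/`Arguments` types that `Function.h` previously defined inline with `BufferArg`/`BufferArgs` plus `TensorShape` from the new `BufferArg.h` header. `BufferArg.h` itself is not part of the diff, so the following is only a minimal mock of the interface the updated call sites appear to rely on (`shape()`, `ndims()`, `operator[]`, shape equality, `data<real>()`, and `BufferArgs::addArg`); the actual PaddlePaddle definitions differ in detail (for example, the real `addArg` overloads accept `Matrix` objects rather than raw pointers).

```cpp
#include <cstddef>
#include <initializer_list>
#include <vector>

using real = float;  // paddle's `real` is float or double depending on the build

// Minimal mock of TensorShape: an ordered list of dimension sizes.
class TensorShape {
public:
  TensorShape() = default;
  TensorShape(std::initializer_list<size_t> dims) : dims_(dims) {}
  size_t ndims() const { return dims_.size(); }
  size_t operator[](size_t i) const { return dims_[i]; }
  bool operator==(const TensorShape& other) const {
    return dims_ == other.dims_;
  }

private:
  std::vector<size_t> dims_;
};

// Minimal mock of BufferArg: an untyped buffer paired with its shape.
class BufferArg {
public:
  BufferArg(void* buf, const TensorShape& shape) : buf_(buf), shape_(shape) {}
  const TensorShape& shape() const { return shape_; }
  template <typename T>
  T* data() const {
    return static_cast<T*>(buf_);
  }

private:
  void* buf_;
  TensorShape shape_;
};

// Minimal mock of BufferArgs: the ordered argument list passed to calc().
class BufferArgs {
public:
  size_t size() const { return args_.size(); }
  const BufferArg& operator[](size_t i) const { return args_[i]; }
  // The real addArg overloads wrap paddle::Matrix and friends; this mock
  // simply takes a raw buffer plus a shape.
  void addArg(real* buf, const TensorShape& shape) {
    args_.emplace_back(buf, shape);
  }

private:
  std::vector<BufferArg> args_;
};
```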
@@ -125,27 +125,25 @@ public:
     pow_ = config.get<real>("pow");
   }
 
-  void calc(const Arguments& inputs,
-            const Arguments& outputs,
-            const Arguments& inouts) override {
+  void calc(const BufferArgs& inputs,
+            const BufferArgs& outputs,
+            const BufferArgs& inouts) override {
     CHECK_EQ(1, inputs.size());
     CHECK_EQ(2, outputs.size());
     CHECK_EQ(0, inouts.size());
 
-    CHECK_EQ(inputs[0].dims_.size(), 4);
-    for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
-      CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
-      CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]);
-    }
+    CHECK_EQ(inputs[0].shape().ndims(), 4);
+    CHECK(inputs[0].shape() == outputs[0].shape());
+    CHECK(inputs[0].shape() == outputs[1].shape());
 
-    size_t samples = inputs[0].dims_[0];
-    size_t channels = inputs[0].dims_[1];
-    size_t height = inputs[0].dims_[2];
-    size_t width = inputs[0].dims_[3];
+    size_t samples = inputs[0].shape()[0];
+    size_t channels = inputs[0].shape()[1];
+    size_t height = inputs[0].shape()[2];
+    size_t width = inputs[0].shape()[3];
 
-    CrossMapNormal<Device>(outputs[0].getData(),
-                           outputs[1].getData(),
-                           inputs[0].getData(),
+    CrossMapNormal<Device>(outputs[0].data<real>(),
+                           outputs[1].data<real>(),
+                           inputs[0].data<real>(),
                            samples,
                            channels,
                            height,
@@ -177,31 +175,29 @@ public:
     pow_ = config.get<real>("pow");
   }
 
-  void calc(const Arguments& inputs,
-            const Arguments& outputs,
-            const Arguments& inouts) override {
+  void calc(const BufferArgs& inputs,
+            const BufferArgs& outputs,
+            const BufferArgs& inouts) override {
     CHECK_EQ(4, inputs.size());
     CHECK_EQ(1, outputs.size());
     CHECK_EQ(0, inouts.size());
 
-    CHECK_EQ(inputs[0].dims_.size(), 4);
-    for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
-      CHECK_EQ(inputs[0].dims_[i], inputs[1].dims_[i]);
-      CHECK_EQ(inputs[0].dims_[i], inputs[2].dims_[i]);
-      CHECK_EQ(inputs[0].dims_[i], inputs[3].dims_[i]);
-      CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
-    }
+    CHECK_EQ(inputs[0].shape().ndims(), 4);
+    CHECK(inputs[0].shape() == inputs[1].shape());
+    CHECK(inputs[0].shape() == inputs[2].shape());
+    CHECK(inputs[0].shape() == inputs[3].shape());
+    CHECK(inputs[0].shape() == outputs[0].shape());
 
-    size_t samples = inputs[0].dims_[0];
-    size_t channels = inputs[0].dims_[1];
-    size_t height = inputs[0].dims_[2];
-    size_t width = inputs[0].dims_[3];
+    size_t samples = inputs[0].shape()[0];
+    size_t channels = inputs[0].shape()[1];
+    size_t height = inputs[0].shape()[2];
+    size_t width = inputs[0].shape()[3];
 
-    CrossMapNormalGrad<Device>(outputs[0].getData(),
-                               inputs[0].getData(),
-                               inputs[1].getData(),
-                               inputs[2].getData(),
-                               inputs[3].getData(),
+    CrossMapNormalGrad<Device>(outputs[0].data<real>(),
+                               inputs[0].data<real>(),
+                               inputs[1].data<real>(),
+                               inputs[2].data<real>(),
+                               inputs[3].data<real>(),
                                samples,
                                channels,
                                height,
...
@@ -16,57 +16,12 @@ limitations under the License. */
 
 #include <map>
 #include <vector>
+#include "BufferArg.h"
 #include "paddle/math/Matrix.h"
 #include "paddle/utils/ClassRegistrar.h"
 
 namespace paddle {
 
-enum DeviceType {
-  DEVICE_TYPE_UNSPECIFIED = 0,
-  DEVICE_TYPE_CPU = 1,
-  DEVICE_TYPE_GPU = 2,
-};
-
-template <DeviceType Device>
-struct MatrixT;
-
-template <>
-struct MatrixT<DEVICE_TYPE_CPU> {
-  using type = CpuMatrix;
-};
-
-template <>
-struct MatrixT<DEVICE_TYPE_GPU> {
-  using type = GpuMatrix;
-};
-
-template <DeviceType Device>
-struct SequenceT;
-
-template <>
-struct SequenceT<DEVICE_TYPE_CPU> {
-  using type = CpuIVector;
-};
-
-template <>
-struct SequenceT<DEVICE_TYPE_GPU> {
-  using type = GpuIVector;
-};
-
-typedef std::vector<size_t> Dims;
-
-class Tensor {
-public:
-  Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {}
-
-  real* getData() const { return buf_; }
-
-  real* buf_;
-  Dims dims_;
-};
-
-typedef std::vector<Tensor> Arguments;
-
 class FuncConfig {
 public:
   union value {
@@ -92,9 +47,9 @@ public:
   virtual void init(const FuncConfig& config) {}
 
-  virtual void calc(const Arguments& inputs,
-                    const Arguments& outputs,
-                    const Arguments& inouts) {}
+  virtual void calc(const BufferArgs& inputs,
+                    const BufferArgs& outputs,
+                    const BufferArgs& inouts) {}
 
   static ClassRegistrar<FunctionBase> funcRegistrar_;
 };
...
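With the virtual signature changed, every `Function` implementation follows the pattern visible in the hunks above: argument-count checks, whole-shape equality checks in place of the old per-dimension loop, then a typed `data<real>()` unpack. Below is a toy, hypothetical `ScaleFunc` (not part of this commit) written against the mock types sketched earlier, just to show the new calling convention end to end.

```cpp
#include <cassert>
#include <cstdio>

// Toy Function in the new style: out = in * 2, element-wise.
// Relies on the mock TensorShape/BufferArg/BufferArgs defined above.
struct ScaleFunc {
  void calc(const BufferArgs& inputs,
            const BufferArgs& outputs,
            const BufferArgs& inouts) {
    assert(inputs.size() == 1);
    assert(outputs.size() == 1);
    assert(inouts.size() == 0);
    // One shape-equality check replaces the old per-dimension loop.
    assert(inputs[0].shape() == outputs[0].shape());

    size_t n = 1;
    for (size_t i = 0; i < inputs[0].shape().ndims(); ++i) {
      n *= inputs[0].shape()[i];
    }
    const real* in = inputs[0].data<real>();
    real* out = outputs[0].data<real>();
    for (size_t i = 0; i < n; ++i) out[i] = in[i] * 2;
  }
};

int main() {
  real in[6] = {1, 2, 3, 4, 5, 6};
  real out[6] = {};
  TensorShape shape({2, 3});

  BufferArgs inputs, outputs, inouts;
  inputs.addArg(in, shape);
  outputs.addArg(out, shape);

  ScaleFunc().calc(inputs, outputs, inouts);
  std::printf("%g %g\n", out[0], out[5]);  // prints: 2 12
  return 0;
}
```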
@@ -71,11 +71,16 @@ void CMRProjectionNormLayer::forward(PassType passType) {
   Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_);
 
-  dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_};
-  forward_[0]->calc(
-      {Tensor(input->getData(), dims_)},
-      {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)},
-      {});
+  shape_ = TensorShape({batchSize, channels_, imgSizeH_, imgSizeW_});
+
+  BufferArgs inputs;
+  BufferArgs outputs;
+  BufferArgs inouts;
+  inputs.addArg(*input, shape_);
+  outputs.addArg(*outV, shape_);
+  outputs.addArg(*denoms_, shape_);
+
+  forward_[0]->calc(inputs, outputs, inouts);
 }
 
 void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
@@ -90,11 +95,14 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
   MatrixPtr localOutV = getOutputValue();
   MatrixPtr preOutV = inputLayers_[0]->getOutputValue();
 
-  backward_[0]->calc({Tensor(preOutV->getData(), dims_),
-                      Tensor(localOutV->getData(), dims_),
-                      Tensor(localGrad->getData(), dims_),
-                      Tensor(denoms_->getData(), dims_)},
-                     {Tensor(preOutGrad->getData(), dims_)},
-                     {});
+  BufferArgs inputs;
+  BufferArgs outputs;
+  BufferArgs inouts;
+  inputs.addArg(*preOutV, shape_);
+  inputs.addArg(*localOutV, shape_);
+  inputs.addArg(*localGrad, shape_);
+  inputs.addArg(*denoms_, shape_);
+  outputs.addArg(*preOutGrad, shape_);
+
+  backward_[0]->calc(inputs, outputs, inouts);
 }
 
 }  // namespace paddle
@@ -41,6 +41,6 @@ public:
   void backward(const UpdateCallback& callback = nullptr);
 
 protected:
-  Dims dims_;
+  TensorShape shape_;
 };
 
 }  // namespace paddle
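At the call sites, the change trades brace-initialized `std::vector<Tensor>` literals for named `BufferArgs` objects built with `addArg`, which keeps each buffer bundled with its `TensorShape` and makes the empty `inouts` list an explicit argument rather than an anonymous `{}`. Storing a `TensorShape` member (`shape_`) instead of a raw `Dims` vector also means the shape is rebuilt with the current `batchSize` on every forward pass, so the same function objects can serve varying batch sizes.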