提交 5edbe32f 编写于 作者: H hedaoyuan 提交者: GitHub

Merge pull request #1216 from hedaoyuan/cmrnorm

Function Adds some properties
......@@ -162,38 +162,64 @@ template <DeviceType Device>
class CrossMapNormalFunc : public FunctionBase {
public:
// Loads the "size", "scale" and "pow" settings from `config` and records
// how many arguments this Function works with.
void init(const FuncConfig& config) override {
  // Settings looked up by key.
  size_ = config.get<size_t>("size");
  scale_ = config.get<real>("scale");
  pow_ = config.get<real>("pow");
  // Argument counts: two outputs, one input.
  numOutputs_ = 2;
  numInputs_ = 1;
}
// Runs the forward computation by delegating to CrossMapNormal<Device>.
//
// inputs[0]  : 4-dimensional input; its dimensions 0..3 are forwarded as
//              batchSize, maps, rows and columns.
// outputs[0] : first result buffer, same shape as inputs[0].
// outputs[1] : second result buffer, same shape as inputs[0].
// Both outputs must carry the ASSIGN_TO argument type.
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
  // Argument counts and shapes are validated in a single place.
  check(inputs, outputs);
  // ArgType check still on here,
  // not sure whether it is better to put inside the check.
  CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
  CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO);

  size_t batchSize = inputs[0].shape()[0];
  size_t maps = inputs[0].shape()[1];
  size_t rows = inputs[0].shape()[2];
  size_t columns = inputs[0].shape()[3];

  CrossMapNormal<Device>(outputs[0].data<real>(),
                         outputs[1].data<real>(),
                         inputs[0].data<real>(),
                         batchSize,
                         maps,
                         rows,
                         columns,
                         size_,
                         scale_,
                         pow_);
}
// Validates the arguments of the forward Function.
// The checks require, in order:
//   - inputs.size() equals numInputs_ and outputs.size() equals numOutputs_;
//   - inputs[0] is 4-dimensional;
//   - outputs[0] and outputs[1] both have exactly the shape of inputs[0].
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == outputs[0].shape());
CHECK(inputs[0].shape() == outputs[1].shape());
}
// Approximate number of floating-point operations of the forward pass.
// Only the four dimensions of inputs[0] are read, so the arguments need to
// carry a shape but no data; `outputs` is not used.
size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
  CHECK_EQ((size_t)numInputs_, inputs.size());
  const size_t sampleCount = inputs[0].shape()[0];
  const size_t mapCount = inputs[0].shape()[1];
  const size_t rowCount = inputs[0].shape()[2];
  const size_t columnCount = inputs[0].shape()[3];

  // Estimated cost per element; the result is an approximate value.
  const size_t perElement = size_ * 2 + 3;
  return sampleCount * mapCount * rowCount * columnCount * perElement;
}
private:
size_t size_;
real scale_;
......@@ -236,21 +262,18 @@ template <DeviceType Device>
class CrossMapNormalGradFunc : public FunctionBase {
public:
// Loads the "size", "scale" and "pow" settings from `config` and records
// how many arguments this Function works with.
void init(const FuncConfig& config) override {
  // Settings looked up by key.
  size_ = config.get<size_t>("size");
  scale_ = config.get<real>("scale");
  pow_ = config.get<real>("pow");
  // Argument counts: one output, four inputs.
  numOutputs_ = 1;
  numInputs_ = 4;
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)4, inputs.size());
CHECK_EQ((size_t)1, outputs.size());
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == inputs[1].shape());
CHECK(inputs[0].shape() == inputs[2].shape());
CHECK(inputs[0].shape() == inputs[3].shape());
CHECK(inputs[0].shape() == outputs[0].shape());
check(inputs, outputs);
if (outputs[0].getArgType() != ADD_TO) {
// Currently, some algorithm implementations are ASSIGN_TO mode,
// if need to support the ADD_TO calculation, need to clear the output.
......@@ -259,25 +282,52 @@ public:
tmp.zero();
}
size_t samples = inputs[0].shape()[0];
size_t channels = inputs[0].shape()[1];
size_t height = inputs[0].shape()[2];
size_t width = inputs[0].shape()[3];
size_t batchSize = inputs[0].shape()[0];
size_t maps = inputs[0].shape()[1];
size_t rows = inputs[0].shape()[2];
size_t columns = inputs[0].shape()[3];
CrossMapNormalGrad<Device>(outputs[0].data<real>(),
inputs[0].data<real>(),
inputs[1].data<real>(),
inputs[2].data<real>(),
inputs[3].data<real>(),
samples,
channels,
height,
width,
batchSize,
maps,
rows,
columns,
size_,
scale_,
pow_);
}
// Validates the arguments of the gradient Function.
// The checks require, in order:
//   - inputs.size() equals numInputs_ and outputs.size() equals numOutputs_;
//   - inputs[0] is 4-dimensional;
//   - inputs[1], inputs[2], inputs[3] and outputs[0] all have exactly the
//     shape of inputs[0].
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == inputs[1].shape());
CHECK(inputs[0].shape() == inputs[2].shape());
CHECK(inputs[0].shape() == inputs[3].shape());
CHECK(inputs[0].shape() == outputs[0].shape());
}
// Approximate number of floating-point operations of the backward pass.
// Only the four dimensions of inputs[0] are read, so a single input that
// carries the shape is sufficient; passing the full argument list is also
// accepted. `outputs` is not used.
size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
  // At least one input is required because only inputs[0] is accessed.
  CHECK_LE((size_t)1, inputs.size());
  size_t batchSize = inputs[0].shape()[0];
  size_t maps = inputs[0].shape()[1];
  size_t rows = inputs[0].shape()[2];
  size_t columns = inputs[0].shape()[3];

  // number of floating-point operations
  // an approximate value
  return batchSize * maps * rows * columns * (size_ * 4 + 2);
}
private:
size_t size_;
real scale_;
......
......@@ -153,7 +153,36 @@ public:
// Base-class default: does nothing. Declared virtual so that a derived
// Function can supply the actual computation.
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
// Checks whether the BufferType and shape of the inputs and outputs
// arguments of the Function are correct.
// A Function's calc generally calls this check to validate its arguments,
// and a caller may also run the check on its own arguments before calc
// is called.
// The base-class default performs no checks.
virtual void check(const BufferArgs& inputs, const BufferArgs& outputs) {}
// Calculates the number of floating-point operations of this Function.
// The inputs and outputs arguments do not need to contain actual data,
// only the shape.
// Some Functions have identical input and output shapes, so passing the
// complete list of arguments may not be necessary for them; passing the
// full arguments is always correct for this interface.
// The base-class default reports 0 operations.
virtual size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) {
return 0;
}
// Returns numInputs_. The member is a size_t, so it is implicitly
// converted to int by this accessor.
int getNumInputs() const { return numInputs_; }
// Returns numOutputs_. The member is a size_t, so it is implicitly
// converted to int by this accessor.
int getNumOutputs() const { return numOutputs_; }
static ClassRegistrar<FunctionBase> funcRegistrar_;
protected:
// numInputs_ and numOutputs_ are the maximum numbers of inputs and
// outputs supported by the Function.
// Some functions are optimized for input and output, so when the number
// of arguments is compared for these functions,
// inputs.size() <= numInputs_ or outputs.size() <= numOutputs_ may hold.
// Both default to 0 so that getNumInputs()/getNumOutputs() return a
// defined value even before a derived init() assigns them.
size_t numInputs_ = 0;
size_t numOutputs_ = 0;
};
// Builds the string literal "<typeName>-<deviceName>": both macro
// arguments are stringized and joined with "-" through adjacent
// string-literal concatenation.
#define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册