提交 1e6c917e 编写于 作者: H hedaoyuan

fix unit test of paramRelu

上级 7df67bae
......@@ -1311,7 +1311,9 @@ void GpuMatrix::paramReluForward(Matrix& data, Matrix& W) {
real* w = W.getData();
size_t numElements = data.getWidth();
size_t numSamples = data.getHeight();
size_t partial_sum = numElements / (W.getHeight() * W.getWidth());
size_t paraSize = W.getHeight() * W.getWidth();
CHECK(!(numElements % paraSize)); // this check from ParameterReluLayer::init
size_t partial_sum = numElements / paraSize;
real* output = getData();
hl_param_relu_forward(output, input, w, numElements, numSamples, partial_sum);
}
......@@ -1324,7 +1326,9 @@ void GpuMatrix::paramReluBackwardW(Matrix& oGrad, Matrix& data) {
real* wgrad = data_;
size_t numElements = data.getWidth();
size_t numSamples = data.getHeight();
size_t partial_sum = numElements / (this->getHeight() * this->getWidth());
size_t paraSize = this->getHeight() * this->getWidth();
CHECK(!(numElements % paraSize)); // this check from ParameterReluLayer::init
size_t partial_sum = numElements / paraSize;
hl_param_relu_backward_w(
wgrad, ograd, input, numElements, numSamples, partial_sum);
}
......@@ -1336,7 +1340,9 @@ void GpuMatrix::paramReluBackwardDiff(Matrix& oGrad, Matrix& data, Matrix& W) {
real* w = W.getData();
size_t numElements = data.getWidth();
size_t numSamples = data.getHeight();
size_t partial_sum = numElements / (W.getHeight() * W.getWidth());
size_t paraSize = W.getHeight() * W.getWidth();
CHECK(!(numElements % paraSize)); // this check from ParameterReluLayer::init
size_t partial_sum = numElements / paraSize;
hl_param_relu_backward_diff(
ograd, input, w, diff, numElements, numSamples, partial_sum);
}
......@@ -3764,7 +3770,9 @@ void CpuMatrix::paramReluForward(Matrix& data, Matrix& W) {
real* w = W.getData();
size_t numElements = data.getWidth();
size_t numSamples = data.getHeight();
size_t partial_sum = numElements / (W.getHeight() * W.getWidth());
size_t paraSize = W.getHeight() * W.getWidth();
CHECK(!(numElements % paraSize)); // this check from ParameterReluLayer::init
size_t partial_sum = numElements / paraSize;
for (size_t n = 0, k = 0; n < numSamples; ++n) {
for (size_t i = 0; i < numElements; ++i, ++k) {
data_[k] = input[k] > 0 ? input[k] : input[k] * w[i / partial_sum];
......@@ -3778,7 +3786,9 @@ void CpuMatrix::paramReluBackwardW(Matrix& oGrad, Matrix& data) {
real* wgrad = data_;
size_t numElements = data.getWidth();
size_t numSamples = data.getHeight();
size_t partial_sum = numElements / (this->getHeight() * this->getWidth());
size_t paraSize = this->getHeight() * this->getWidth();
CHECK(!(numElements % paraSize)); // this check from ParameterReluLayer::init
size_t partial_sum = numElements / paraSize;
for (size_t n = 0, k = 0; n < numSamples; ++n) {
for (size_t i = 0; i < numElements; ++i, ++k) {
wgrad[i / partial_sum] += ograd[k] * (input[k] > 0 ? 0 : input[k]);
......@@ -3793,7 +3803,9 @@ void CpuMatrix::paramReluBackwardDiff(Matrix& oGrad, Matrix& data, Matrix& W) {
real* w = W.getData();
size_t numElements = data.getWidth();
size_t numSamples = data.getHeight();
size_t partial_sum = numElements / (W.getHeight() * W.getWidth());
size_t paraSize = W.getHeight() * W.getWidth();
CHECK(!(numElements % paraSize)); // this check from ParameterReluLayer::init
size_t partial_sum = numElements / paraSize;
for (size_t n = 0, k = 0; n < numSamples; ++n) {
for (size_t i = 0; i < numElements; ++i, ++k) {
diff[k] += ograd[k] * (input[k] > 0 ? 1 : w[i / partial_sum]);
......
......@@ -224,10 +224,11 @@ void testParamReluBackwardW(int height, int width, int w_height, int w_width) {
}
TEST(Matrix, paramRelu) {
for (auto height : {10, 100}) {
for (auto width : {10, 100}) {
for (auto height : {10, 40, 100}) {
for (auto width : {10, 40, 100}) {
for (auto w_height : {1, 2}) {
for (auto w_width : {1, 2}) {
if (width % (w_height * w_width)) continue;
testParamReluForward(height, width, w_height, w_width);
testParamReluBackwardW(height, width, w_height, w_width);
}
......
......@@ -773,10 +773,11 @@ void testParamReluBackwardDiff(int height,
}
TEST(Matrix, paramReluBackwardDiff) {
for (auto height : {10, 100}) {
for (auto width : {10, 100}) {
for (auto height : {10, 40, 100}) {
for (auto width : {10, 40, 100}) {
for (auto w_height : {1, 2}) {
for (auto w_width : {1, 2}) {
if (width % (w_height * w_width)) continue;
testParamReluBackwardDiff(height, width, w_height, w_width);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册