提交 3ae9aa93 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #4860 from tensor-tang/merge_grad_gtest

enable merge grad unit test
......@@ -21,6 +21,10 @@ limitations under the License. */
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
#ifdef PADDLE_USE_MKLDNN
#include "paddle/gserver/layers/MKLDNNLayer.h"
#endif
#ifndef PADDLE_MOBILE_INFERENCE
#include "MultiNetwork.h"
#include "RecurrentGradientMachine.h"
......@@ -300,6 +304,17 @@ void NeuralNetwork::backward(const UpdateCallback& callback) {
}
}
void NeuralNetwork::finish() {
#ifdef PADDLE_USE_MKLDNN
FOR_EACH_R(layer, layers_) {
MKLDNNLayerPtr dnnLayer = std::dynamic_pointer_cast<MKLDNNLayer>(*layer);
if (dnnLayer) {
dnnLayer->convertWeightsToPaddle();
}
}
#endif
}
Argument NeuralNetwork::getLayerOutput(const std::string& layerName) {
return getLayer(layerName)->getOutput();
}
......
......@@ -134,6 +134,9 @@ public:
const std::string& getName() const { return subModelName_; }
/// some finish work, like convert the weight format of MKLDNNLayers
void finish() override;
protected:
/**
* The constructor of NeuralNetwork.
......
......@@ -313,6 +313,7 @@ void MKLDNNConvLayer::resetOutValue(
cvtOutVal_ = MKLDNNMatrix::createReorder(out, cpuOutVal_);
CHECK(cvtOutVal_) << "should not be empty";
} else {
cpuOut->setData(output_.value->getData());
cpuOutVal_ = out;
}
// when output is cpu device, change the mkldnn output value and make them
......@@ -456,17 +457,18 @@ void MKLDNNConvLayer::resetOutGrad(
MKLDNNLayer::resetOutGrad(out, outVal_->getPrimitiveDesc());
} else {
const MatrixPtr& cpuOut = getOutput(CPU_DEVICE).grad;
// always share the same grad data of CPU output
// then the activation can get the right grad from output_.grad
output_.grad->setData(cpuOut->getData());
// same PrimitiveDesc with cpuInVal_
CHECK(cpuOutVal_);
cpuOutGrad_ = MKLDNNMatrix::create(cpuOut, cpuOutVal_->getPrimitiveDesc());
// create reorder if primitive desc does not match
if (cpuOutGrad_->getPrimitiveDesc() != outVal_->getPrimitiveDesc()) {
out = MKLDNNMatrix::create(output_.grad, outVal_->getPrimitiveDesc());
out = MKLDNNMatrix::create(nullptr, outVal_->getPrimitiveDesc());
cvtOutGrad_ = MKLDNNMatrix::createReorder(cpuOutGrad_, out);
CHECK(cvtOutGrad_);
} else {
// share the same data of CPU output
output_.grad->setData(cpuOut->getData());
out = cpuOutGrad_;
}
}
......
......@@ -46,6 +46,9 @@ protected:
// backward also need reset after reset forward handle
bool needResetBwd_;
// is output only mkldnn
bool outputOnlyMKLDNN_;
// mkldnn engine, stream and primivtives
mkldnn::engine engine_;
std::shared_ptr<MKLDNNStream> stream_;
......@@ -141,6 +144,9 @@ public:
updateInputData();
}
if (!outputOnlyMKLDNN_) {
clearGrads();
}
stream_->submit(pipelineFwd_);
}
......@@ -389,7 +395,8 @@ protected:
CHECK_EQ(outputOtherDevice_[i].deviceId, CPU_DEVICE)
<< "Only support other device is CPU yet";
}
return outputOtherDevice_.size() == 0;
outputOnlyMKLDNN_ = outputOtherDevice_.size() == 0;
return outputOnlyMKLDNN_;
}
/**
......@@ -398,6 +405,16 @@ protected:
void setDevice(int id) { deviceId_ = id; }
private:
/**
* clear all grad
*/
void clearGrads() {
output_.grad->zeroMem();
for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
outputOtherDevice_[i].grad->zeroMem();
}
}
/**
* Set deviceId of the params used in this layer.
*/
......
......@@ -146,6 +146,7 @@ void MKLDNNPoolLayer::resetOutValue(MKLDNNMatrixPtr& out) {
cvtOutVal_ = MKLDNNMatrix::createReorder(out, cpuOutVal_);
CHECK(cvtOutVal_) << "should not be emptry";
} else {
cpuOut->setData(output_.value->getData());
cpuOutVal_ = out;
}
output_.value = std::dynamic_pointer_cast<Matrix>(cpuOutVal_);
......@@ -213,15 +214,16 @@ void MKLDNNPoolLayer::resetOutGrad(MKLDNNMatrixPtr& out) {
MKLDNNLayer::resetOutGrad(out, outVal_->getPrimitiveDesc());
} else {
const MatrixPtr& cpuOut = getOutput(CPU_DEVICE).grad;
// always share the same grad data of CPU output
// then the activation can get the right grad from output_.grad
output_.grad->setData(cpuOut->getData());
cpuOutGrad_ = MKLDNNMatrix::create(
cpuOut, memory::dims{bs_, oc_, oh_, ow_}, format::nchw, engine_);
if (cpuOutGrad_->getPrimitiveDesc() != outVal_->getPrimitiveDesc()) {
out = MKLDNNMatrix::create(output_.grad, outVal_->getPrimitiveDesc());
out = MKLDNNMatrix::create(nullptr, outVal_->getPrimitiveDesc());
cvtOutGrad_ = MKLDNNMatrix::createReorder(cpuOutGrad_, out);
CHECK(cvtOutGrad_) << "should not be emptry";
} else {
// share the same data of CPU output
output_.grad->setData(cpuOut->getData());
out = cpuOutGrad_;
}
}
......
......@@ -26,7 +26,10 @@ if(WITH_MKLDNN)
test_MKLDNN.cpp
MKLDNNTester.cpp
LayerGradUtil.cpp)
add_test(NAME test_MKLDNN COMMAND test_MKLDNN)
add_test(NAME test_MKLDNN
COMMAND .set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python
${CMAKE_CURRENT_BINARY_DIR}/test_MKLDNN
WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle)
endif()
################ test_CRFLayerGrad ####################
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "MKLDNNTester.h"
#include "paddle/gserver/layers/MKLDNNBase.h"
#include "paddle/gserver/layers/MKLDNNLayer.h"
#include "paddle/trainer/Trainer.h"
namespace paddle {
......@@ -315,6 +316,7 @@ void MKLDNNTester::runOnce() {
auto& value = para->getBuf(PARAMETER_VALUE);
real lr = 1e-3;
value->add(*grad, lr);
grad->zeroMem();
};
randomTopDiffs();
dnnLayer_->backward(updateCallback);
......@@ -411,4 +413,143 @@ void MKLDNNTester::run(const TestConfig& dnn,
}
}
void MKLDNNTester::initArgument(DataIn& data,
const std::string& configPath,
const size_t iter) {
TrainerConfigHelper config(configPath);
size_t batchSize = config.getOptConfig().batch_size();
data.inArgs.resize(iter);
data.outGrads.resize(iter);
data.paraValues.clear();
for (const auto& layer_name : config.getModelConfig().input_layer_names()) {
auto layer_config = std::find_if(config.getModelConfig().layers().begin(),
config.getModelConfig().layers().end(),
[=](const LayerConfig& layer_config) {
return layer_config.name() == layer_name;
});
CHECK(layer_config != config.getModelConfig().layers().end());
size_t layerSize = layer_config->size();
for (size_t i = 0; i < iter; ++i) {
Argument arg;
arg.value = Matrix::create(batchSize, layerSize, false, false);
arg.grad = Matrix::create(batchSize, layerSize, false, false);
arg.value->randomizeUniform();
arg.value->add(-0.5);
arg.value->sigmoid(*arg.value);
arg.grad->zeroMem();
arg.ids = VectorT<int>::create(batchSize, false);
arg.ids->rand(layerSize);
generateSequenceStartPositions(batchSize, arg.sequenceStartPositions);
data.inArgs[i].push_back(arg);
}
}
for (const auto& layer_name : config.getModelConfig().output_layer_names()) {
auto layer_config = std::find_if(config.getModelConfig().layers().begin(),
config.getModelConfig().layers().end(),
[=](const LayerConfig& layer_config) {
return layer_config.name() == layer_name;
});
CHECK(layer_config != config.getModelConfig().layers().end());
size_t layerSize = layer_config->size();
for (size_t i = 0; i < iter; ++i) {
MatrixPtr grad = Matrix::create(batchSize, layerSize, false, false);
grad->randomizeUniform();
data.outGrads[i].push_back(grad);
}
}
for (const auto& para_config : config.getModelConfig().parameters()) {
VectorPtr value = Vector::create(para_config.size(), false);
value->randnorm(0, 2);
data.paraValues.push_back(value);
}
}
void MKLDNNTester::getOutResult(const std::string& configPath,
DataIn& in,
DataOut& out,
bool use_mkldnn,
size_t iter) {
FLAGS_use_gpu = false;
FLAGS_use_mkldnn = use_mkldnn;
*ThreadLocalRand::getSeed() = 1;
srand(1);
Trainer trainer;
auto config = std::make_shared<TrainerConfigHelper>(configPath);
trainer.init(config, false);
auto gradientMachine = trainer.getGradientMachine();
std::vector<ParameterPtr> parameters = gradientMachine->getParameters();
for (size_t i = 0; i < in.paraValues.size(); i++) {
parameters[i]->getBuf(PARAMETER_VALUE)->copyFrom(*in.paraValues[i]);
}
UpdateCallback simpleUpdate = [](Parameter* para) {
auto& grad = para->getBuf(PARAMETER_GRADIENT);
auto& value = para->getBuf(PARAMETER_VALUE);
real lr = 1e-2;
value->add(*grad, lr);
grad->zeroMem();
};
vector<Argument> outArgs;
gradientMachine->start();
out.outValues.clear();
out.paraValues.clear();
for (size_t i = 0; i < iter; ++i) {
VLOG(MKLDNN_TESTS) << "runing iteration " << i;
gradientMachine->forward(in.inArgs[i], &outArgs, PASS_TRAIN);
// save forward result
for (size_t k = 0; k < outArgs.size(); k++) {
MatrixPtr value = Matrix::create(outArgs[k].value->getHeight(),
outArgs[k].value->getWidth(),
false,
false);
value->copyFrom(*outArgs[k].value);
out.outValues.push_back(value);
}
// random backward input
for (size_t k = 0; k < outArgs.size(); k++) {
outArgs[k].grad->copyFrom(*in.outGrads[i][k]);
}
gradientMachine->backward(simpleUpdate);
}
gradientMachine->finish();
// save param value
for (size_t i = 0; i < in.paraValues.size(); i++) {
VectorPtr val = Vector::create(
parameters[i]->getBuf(PARAMETER_VALUE)->getSize(), false);
val->copyFrom(*parameters[i]->getBuf(PARAMETER_VALUE));
out.paraValues.push_back(val);
}
}
void MKLDNNTester::compareResult(DataOut& ref, DataOut& dnn, float eps) {
CHECK_EQ(ref.outValues.size(), dnn.outValues.size());
CHECK_EQ(ref.paraValues.size(), dnn.paraValues.size());
for (size_t i = 0; i < ref.outValues.size(); i++) {
EXPECT_LE(fabs(compareMatrix(ref.outValues[i], dnn.outValues[i])), eps);
}
for (size_t i = 0; i < ref.paraValues.size(); i++) {
EXPECT_LE(fabs(compareVector(ref.paraValues[i], dnn.paraValues[i])), eps);
}
}
void MKLDNNTester::runBranchesTest(const std::string& configPath,
size_t iter,
float eps) {
DataIn in;
initArgument(in, configPath, iter);
DataOut outCpu, outDnn;
getOutResult(configPath, in, outCpu, false, iter);
getOutResult(configPath, in, outDnn, true, iter);
compareResult(outCpu, outDnn, eps);
}
} // namespace paddle
......@@ -33,6 +33,17 @@ class MKLDNNTester {
NUM = 2, // Number of total
};
struct DataIn {
std::vector<std::vector<Argument>> inArgs;
std::vector<std::vector<MatrixPtr>> outGrads;
std::vector<VectorPtr> paraValues;
};
struct DataOut {
std::vector<MatrixPtr> outValues;
std::vector<VectorPtr> paraValues;
};
protected:
std::vector<TestConfig> configs_;
vector<string> layerNames_;
......@@ -74,7 +85,17 @@ public:
float epsilon = 1e-4,
bool log = false,
int level = MKLDNN_ALL);
void setLogLevel(int lvl) { lvl_ = lvl; }
static void runBranchesTest(const std::string& configPath,
size_t iter = 3,
float eps = 1e-4);
static void initArgument(DataIn& data,
const std::string& configPath,
size_t iter = 3);
static void getOutResult(const std::string& configPath,
DataIn& in,
DataOut& out,
bool use_mkldnn,
size_t iter = 3);
private:
void reset(const TestConfig& dnn, const TestConfig& ref, size_t batchSize);
......@@ -101,8 +122,9 @@ private:
void saveWgt(const vector<ParameterPtr>& from, vector<VectorPtr>& to);
void restoreWgt(const vector<VectorPtr>& from, vector<ParameterPtr>& to);
double compareMatrix(const MatrixPtr& m1, const MatrixPtr& m2);
double compareVector(const VectorPtr& v1, const VectorPtr& v2);
static double compareMatrix(const MatrixPtr& m1, const MatrixPtr& m2);
static double compareVector(const VectorPtr& v1, const VectorPtr& v2);
static void compareResult(DataOut& ref, DataOut& dnn, float eps = 1e-4);
/**
* Get delta percent
......@@ -111,11 +133,11 @@ private:
* else return sum(abs(a-b)) / sum(abs(b))
* The return value should be smaller than eps when passing.
*/
double getDelta(const real* d1,
const real* d2,
size_t len,
const float failRate = 1e-3,
const float thres = 0.1);
static double getDelta(const real* d1,
const real* d2,
size_t len,
const float failRate = 1e-3,
const float thres = 0.1);
};
} // namespace paddle
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *
settings(batch_size=16)
channels = get_config_arg("channels", int, 2)
def two_conv(input, group_name):
out1 = img_conv_layer(input=input,
name=group_name+'_conv1',
filter_size=1,
num_filters=channels,
padding=0,
shared_biases=True,
act=ReluActivation())
out2 = img_conv_layer(input=input,
name=group_name+'_conv2',
filter_size=3,
num_filters=channels,
padding=1,
shared_biases=True,
act=ReluActivation())
return out1, out2
data = data_layer(name ="input", size=channels*16*16)
conv = img_conv_layer(input=data,
num_channels=channels,
filter_size=3,
num_filters=channels,
padding=1,
shared_biases=True,
act=ReluActivation())
a1, a2 = two_conv(input=conv, group_name='a')
concat = concat_layer(input=[a1, a2])
b1, b2 = two_conv(input=conv, group_name='b')
addto = addto_layer(input=[b1, b2])
outputs([concat, addto])
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <paddle/utils/PythonUtil.h>
#include <string>
#include <vector>
#include "MKLDNNTester.h"
......@@ -40,12 +41,13 @@ DECLARE_bool(use_mkldnn);
struct testFcDesc {
int bs;
int ic;
int oc;
int ih, iw; // oh == ow == 1
int oc;
};
static void getMKLDNNFcConfig(TestConfig& cfg, const testFcDesc& pm) {
cfg.layerConfig.set_type("mkldnn_fc");
cfg.layerConfig.set_active_type("relu");
cfg.layerConfig.set_size(pm.oc);
cfg.inputDefs.push_back(
{INPUT_DATA,
......@@ -86,6 +88,7 @@ struct testConvDesc {
static void getMKLDNNConvConfig(TestConfig& cfg, const testConvDesc& pm) {
cfg.layerConfig.set_type("mkldnn_conv");
cfg.layerConfig.set_active_type("relu");
cfg.layerConfig.set_num_filters(pm.oc);
cfg.layerConfig.set_size(pm.oc * pm.oh * pm.ow);
cfg.layerConfig.set_shared_biases(true);
......@@ -158,6 +161,7 @@ struct testPoolDesc {
static void getMKLDNNPoolConfig(TestConfig& cfg, const testPoolDesc& pm) {
cfg.layerConfig.set_type("mkldnn_pool");
cfg.layerConfig.set_active_type("relu");
cfg.layerConfig.set_size(pm.ic * pm.oh * pm.ow);
cfg.inputDefs.push_back(
{INPUT_DATA,
......@@ -244,13 +248,26 @@ TEST(MKLDNNActivation, Activations) {
}
}
// TODO(TJ): add branch test
DECLARE_string(config_args);
TEST(MKLDNNLayer, branches) {
std::vector<std::string> cases = {"conv"};
for (auto name : cases) {
std::string config = "./gserver/tests/mkldnn_branches_" + name + ".conf";
for (auto channels : {2, 32}) {
std::ostringstream oss;
oss << "channels=" << channels;
FLAGS_config_args = oss.str();
MKLDNNTester::runBranchesTest(config);
}
}
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
FLAGS_use_gpu = false;
FLAGS_use_mkldnn = true;
initMain(argc, argv);
initPython(argc, argv);
FLAGS_thread_local_rand_use_global_seed = true;
srand(1);
return RUN_ALL_TESTS();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册