diff --git a/CMakeLists.txt b/CMakeLists.txt index dcd1218a5b0b62f2739b727391aca31b48ed9ccb..06dd5a1332cb9dda8f942a342693ec74fb6184d9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -137,9 +137,9 @@ set(EXTERNAL_LIBS ) if(WITH_GPU) - list(APPEND EXTERNAL_LIB ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY}) + list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY}) if(NOT WITH_DSO) - list(APPEND EXTERNAL_LIB ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY}) + list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY}) endif(NOT WITH_DSO) endif(WITH_GPU) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 1003b1ccd83c85c54f07c8d2b84f993203408941..2c5ec76dfeb8b8485951e4d94896b6758e0cb930 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -32,9 +32,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { public: RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input X of Add").AsNoGradient(); - AddInput("b", "Bias of Add").AsNoGradient(); - AddOutput("Out", "Out of Add").AsNoGradient(); + AddInput("X", "Input X of Add").NotInGradient(); + AddInput("b", "Bias of Add").NotInGradient(); + AddOutput("Out", "Out of Add").NotInGradient(); AddComment("Add Op"); } }; diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index 7077e8aa2c77c24efdbb34ed3a13821fe7678455..ae44a1ffd45dacdc44a72edc630e771e7a2f2990 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -60,7 +60,7 @@ message OpProto { optional bool duplicable = 3 [ default = false ]; optional bool intermediate = 4 [ default = false ]; - optional bool no_gradient = 5 [ default = false ]; + optional bool not_in_gradient = 5 [ default = false ]; } // AttrProto describes the C++ type Attribute. diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index b73dac22d029876de9a012de533647db3dd74cbb..0a2a41f6b62658ac8633a6e384d099f8d6641f33 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -28,7 +28,7 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, const auto& src_arg_list = src_type == OpArgType::IN ? proto->inputs() : proto->outputs(); for (const auto& arg : src_arg_list) { - if (arg.no_gradient() && !is_grad) continue; + if (arg.not_in_gradient() && !is_grad) continue; const std::string src_name = arg.name(); std::string dst_name = is_grad ? GradVarName(src_name) : src_name; dst_inout[dst_name].reserve(src_inout.at(src_name).size()); diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index 0c26293fd29d24a7a40c47bdf055d2758846709b..902c2655e9182d74a48ad13e17a39a3304d5fa57 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -26,10 +26,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker { IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("In1", "a single input"); - AddInput("In2_mult", "a multiple input").AsDuplicable().AsNoGradient(); + AddInput("In2_mult", "a multiple input").AsDuplicable().NotInGradient(); AddInput("In3_mult", "another multiple input").AsDuplicable(); AddOutput("Out1_mult", "a multiple output").AsDuplicable(); - AddOutput("Out2", "a single output").AsNoGradient(); + AddOutput("Out2", "a single output").NotInGradient(); AddComment("op with inputs and outputs ignored in gradient calculating"); } }; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 848baeeeb6493f61c41193a5cc0fc69e93934bfb..807298088981b969622174be753ea0da72067243 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -184,11 +184,8 @@ class OpProtoAndCheckerMaker { return *this; } - // TODO(FengJiayi, yuyang18): `AsNoGradient` is a very bad name, because it - // means that input/output is not needed when calculate gradient. It does - // not mean no gradient when backward. It should be changed soon. - VariableBuilder& AsNoGradient() { - var_->set_no_gradient(true); + VariableBuilder& NotInGradient() { + var_->set_not_in_gradient(true); return *this; } }; diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp index 30f567eaf8248a8fba1b461a2bdbf2aab13f9e08..d201fac65e0459050304195140e1aae081468f43 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.cpp +++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp @@ -57,11 +57,14 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap, } void MKLDNNFcLayer::convertWeightsFromPaddle() { - if (FLAGS_use_mkldnn_wgt) { + if (hasInitedWgt_) { return; } - if (hasInitedWgt_) { + // TODO(TJ): dst format should get from wgtVal_ + int dstFmt = PARAM_FORMAT_MKLDNN_OI; + int srcFmt = weight_->getParameterPtr()->getHeaderFormat(); + if (srcFmt == dstFmt) { return; } @@ -78,6 +81,7 @@ void MKLDNNFcLayer::convertWeightsFromPaddle() { MatrixPtr paddleWgtT; paddleWgt->transpose(paddleWgtT, true); weight_->getW()->copyFrom(*paddleWgtT); + weight_->getParameterPtr()->setHeaderFormat(dstFmt); hasInitedWgt_ = true; } diff --git a/paddle/gserver/tests/MKLDNNTester.cpp b/paddle/gserver/tests/MKLDNNTester.cpp index 99c8c4948c9b05ad15d1217ebb70026bbd48453f..de1635be2af37cd0ba49010199a417090865b0e4 100644 --- a/paddle/gserver/tests/MKLDNNTester.cpp +++ b/paddle/gserver/tests/MKLDNNTester.cpp @@ -330,9 +330,7 @@ void MKLDNNTester::run(const TestConfig& dnn, log_ = log; lvl_ = level; - // Firstly test FLAGS_use_mkldnn_wgt = false - FLAGS_use_mkldnn_wgt = false; - // reset and run once + // Firstly test mkldnn init from PARAM_FORMAT_ORIGINAL weight reset(dnn, ref, batchSize); randomWgtDatas(); clearWgtDiffs(); @@ -342,17 +340,32 @@ void MKLDNNTester::run(const TestConfig& dnn, runOnce(); } - // Then test FLAGS_use_mkldnn_wgt = true - FLAGS_use_mkldnn_wgt = true; - // after run once the mkldnn weight has been stored in dnnlayer + if (parameters_[DNN].empty()) { + // has no paramters + return; + } + + // After run some iterations, the mkldnn weight has been stored in dnnLayer + // and we can also get the mkldnn weight parameter header format. + // Weight parameter should always be index 0 (and bias index 1). + // TODO(TJ): should also consider mean and var format when batchnorm ready + int dnnWgtFmt = parameters_[DNN][0]->getHeaderFormat(); + int refWgtFmt = parameters_[REF][0]->getHeaderFormat(); + if (dnnWgtFmt == refWgtFmt) { + // weight format are equal, so no need check more + return; + } + // then save the weights and restart again vector dnnWgts, refWgts; CHECK_EQ(parameters_[DNN].size(), parameters_[REF].size()); saveWgt(parameters_[DNN], dnnWgts); saveWgt(parameters_[REF], refWgts); - // restart again with flag true + // restart again with dnn weight format reset(dnn, ref, batchSize); + // TODO(TJ): should also considerate mean and var format when batchnorm ready + parameters_[DNN][0]->setHeaderFormat(dnnWgtFmt); // restore wgt restoreWgt(dnnWgts, parameters_[DNN]); diff --git a/paddle/gserver/tests/MKLDNNTester.h b/paddle/gserver/tests/MKLDNNTester.h index 522eeaf24b1949abac057a1e59e9977610be23c0..e55e4493ffdfe45b8cfdee423febd1878b8b3d8a 100644 --- a/paddle/gserver/tests/MKLDNNTester.h +++ b/paddle/gserver/tests/MKLDNNTester.h @@ -108,7 +108,7 @@ private: * if many(>failRate) wrong(abs(dnn-ref)/abs(ref)>thres) points return the * max(diff/ref) * else return sum(abs(a-b)) / sum(abs(b)) - * The return value should smaller than eps when passing. + * The return value should be smaller than eps when passing. */ double getDelta(const real* d1, const real* d2, diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index 49d0f43508b1ee3df0c6b5987942970e1649e310..d3d0e55a674587fb04f43f24d0790de4358f035a 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of mean op"); - AddOutput("Out", "The output of mean op").AsNoGradient(); + AddOutput("Out", "The output of mean op").NotInGradient(); AddComment("Mean Operator"); } }; diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h index fcb703e63bd5a82f9ffac2bf64e06fd0218dbdaa..9848af280b62729bef9243052ceae0b7d8f4c6f5 100644 --- a/paddle/operators/mean_op.h +++ b/paddle/operators/mean_op.h @@ -55,9 +55,10 @@ class MeanGradKernel : public framework::OpKernel { IG->mutable_data(context.GetPlace()); T ig_size = (T)framework::product(IG->dims()); + Eigen::DSizes bcast(ig_size); EigenVector::Flatten(*IG).device(context.GetEigenDevice()) = - EigenScalar::From(*OG) / ig_size; + (EigenVector::From(*OG) / ig_size).broadcast(bcast); } }; diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc index d773a4f2d50e82146a729b1cda085ce86ade89cc..761c6de8d4d2150b30b97b58da95da3d5f33db63 100644 --- a/paddle/operators/sigmoid_op.cc +++ b/paddle/operators/sigmoid_op.cc @@ -44,7 +44,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - ctx.Output(0)->Resize(ctx.Input(0)->dims()); + ctx.Output(framework::GradVarName("X")) + ->Resize(ctx.Input("Y")->dims()); } }; diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h index 11ab923eb346c1f8de3a6bbebdfa874b6530004a..b01a9b3f23283471f8846325075719ba0e75ed35 100644 --- a/paddle/operators/sigmoid_op.h +++ b/paddle/operators/sigmoid_op.h @@ -37,7 +37,7 @@ class SigmoidKernel : public framework::OpKernel { auto Y = EigenVector::Flatten(*output); auto place = context.GetEigenDevice(); - Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp()); + Y.device(place) = 1. / (1. + (-X).exp()); } }; diff --git a/paddle/parameter/Parameter.cpp b/paddle/parameter/Parameter.cpp index ebe36d49376882fe4c1013e19dcf71f452b3e501..f0311095012d944768d80abe423d4a9bfc0e97f5 100644 --- a/paddle/parameter/Parameter.cpp +++ b/paddle/parameter/Parameter.cpp @@ -48,7 +48,8 @@ Parameter::Parameter(const ParameterConfig& config, bool useGpu, bool doInit) deviceId_(-1), sharedCount_(0), updateCounter_(0), - updated_(false) { + updated_(false), + headerFormat_(PARAM_FORMAT_ORIGINAL) { setID(-1); /* capture uninitialized id */ if (useGpu_ && FLAGS_parallel_nn) { /* gpu environment is specified by device property */ @@ -285,7 +286,7 @@ bool Parameter::save(const std::string& filename) const { bool Parameter::save(std::ostream& s) const { CpuVector vec(*bufs_[PARAMETER_VALUE].get()); Header header; - header.version = kFormatVersion; + header.format = headerFormat_; header.valueSize = sizeof(real); header.size = getSize(); @@ -344,8 +345,9 @@ bool Parameter::load(std::istream& s) { Header header; CHECK(s.read(reinterpret_cast(&header), sizeof(header))) << "Fail to read parameter " << getName(); - CHECK_EQ(header.version, kFormatVersion) << "Incorrect format version: " - << header.version; + CHECK(isHeaderFormatSupported(header.format)) << "Incorrect format version: " + << header.format; + headerFormat_ = header.format; CHECK_EQ(header.size, getSize()) << "The size (" << header.size << ") in the file does not match the size " << "(" << getSize() << ") of the parameter: " << getName(); diff --git a/paddle/parameter/Parameter.h b/paddle/parameter/Parameter.h index 0bac76f068ec22bec52766b43e331fe109a34188..e31cbc3dee6c57851c241e117dbbd9b701db9d2c 100644 --- a/paddle/parameter/Parameter.h +++ b/paddle/parameter/Parameter.h @@ -34,6 +34,20 @@ limitations under the License. */ namespace paddle { +typedef enum { + /// The paddle original basic format + PARAM_FORMAT_ORIGINAL = 0, + + /// See mkldnn_memory_format_t in + /// https://github.com/01org/mkl-dnn/blob/master/include/mkldnn_types.h + /// for a detailed description. + /// 2D weights tensor in the format (output channels, input channels). + PARAM_FORMAT_MKLDNN_OI, + + /// The total format items numbers + PARAM_FORMAT_ITEMS, +} PARAM_FORMAT; + class SparsePrefetchRowCpuMatrix; class Parameter; @@ -242,14 +256,30 @@ public: /// Initialize the value to 0 void zeroMem(); - static const int kFormatVersion = 0; /// file header structure struct Header { - int32_t version; // = 0, file format version + int32_t format; // = PARAM_FORMAT uint32_t valueSize; // = sizeof(real) uint64_t size; // = getSize() }; + /** + * @brief Is the header format supported. + */ + static bool isHeaderFormatSupported(int32_t fmt) { + return fmt < PARAM_FORMAT_ITEMS; + } + + /** + * @brief Get the format in header. + */ + int getHeaderFormat() { return headerFormat_; } + + /** + * @brief Set the format in header. + */ + void setHeaderFormat(int32_t fmt) { headerFormat_ = fmt; } + /** * @brief Parameter Update Hook. * @@ -321,6 +351,9 @@ protected: bool updated_; SparseFormat format_; + /// The header format for saving or loading param + int32_t headerFormat_; + std::vector> updaterHooks_; public: diff --git a/paddle/pserver/ParameterServer2.cpp b/paddle/pserver/ParameterServer2.cpp index d7c1d4f788f44c6bfcec040ba24bdc454348c911..54f5c4c0fb4994871edc7a1e52237c9f903ce63b 100644 --- a/paddle/pserver/ParameterServer2.cpp +++ b/paddle/pserver/ParameterServer2.cpp @@ -1032,8 +1032,8 @@ void ParameterServer2::loadValueVector(const LoadValueRequest& request, Parameter::Header header; CHECK(fs.read(reinterpret_cast(&header), sizeof(header))) << "Fail to read parameters in pserver"; - CHECK_EQ(header.version, Parameter::kFormatVersion) - << "Incorrect format version: " << header.version; + CHECK(Parameter::isHeaderFormatSupported(header.format)) + << "Incorrect format version: " << header.format; CHECK_EQ(header.size, (size_t)size_) << "The size (" << header.size << ") in the file does not match the size " << "(" << size_ << ") of the pserver: " << serverId_; @@ -1063,7 +1063,8 @@ void ParameterServer2::saveValueVector(const SaveValueRequest& request, CpuVector& vec = vectors_[PARAMETER_APPLY] ? *vectors_[PARAMETER_APPLY] : *vectors_[PARAMETER_VALUE]; Parameter::Header header; - header.version = Parameter::kFormatVersion; + // TODO(TJ): save param headerFormat_ + header.format = PARAM_FORMAT_ORIGINAL; header.valueSize = sizeof(real); header.size = size_; diff --git a/paddle/trainer/TrainerConfigHelper.cpp b/paddle/trainer/TrainerConfigHelper.cpp index eba40862b926cfe863c569e73a6a3ceabcf1f3b4..a0a365aa0bb0ac26939a02c1cd626d0c17c6a9fe 100644 --- a/paddle/trainer/TrainerConfigHelper.cpp +++ b/paddle/trainer/TrainerConfigHelper.cpp @@ -29,7 +29,6 @@ DECLARE_bool(with_gpu); DECLARE_bool(parallel_nn); DECLARE_string(config_args); DECLARE_bool(use_mkldnn); -DECLARE_bool(use_mkldnn_wgt); const char *kConfigParserModuleName = "paddle.trainer.config_parser"; const char *kConfigParserFuncName = "parse_config_and_serialize"; @@ -47,7 +46,6 @@ TrainerConfigHelper::TrainerConfigHelper(const std::string &configFilePath) << ",with_cost=" << FLAGS_with_cost << ",use_gpu=" << FLAGS_use_gpu << ",parallel_nn=" << FLAGS_parallel_nn << ",use_mkldnn=" << FLAGS_use_mkldnn - << ",use_mkldnn_wgt=" << FLAGS_use_mkldnn_wgt << ",cudnn_version=" << hl_get_cudnn_lib_version(); if (!FLAGS_config_args.empty()) { configArgs << "," << FLAGS_config_args; diff --git a/paddle/utils/Flags.cpp b/paddle/utils/Flags.cpp index 600c83a8487191895de635dd8433f6c44e86ce77..ab1c181c62cdbee8cc5f804ec9aaf63ac5464ad6 100644 --- a/paddle/utils/Flags.cpp +++ b/paddle/utils/Flags.cpp @@ -27,7 +27,6 @@ DEFINE_bool(use_mkldnn, false, "Default still keep use CPU training"); DEFINE_bool(use_mkldnn, false, "Only support CPU training"); #endif -DEFINE_bool(use_mkldnn_wgt, false, "Init weight from CPU weight"); DEFINE_bool(parallel_nn, false, "Whether to use multi-threads to calculate one neural network." diff --git a/paddle/utils/Flags.h b/paddle/utils/Flags.h index 0aca4c0ee036ee8490c0ceca7279df876dc21947..1832bb515ec85df3d7733e01b063a01ad6a3b282 100644 --- a/paddle/utils/Flags.h +++ b/paddle/utils/Flags.h @@ -41,4 +41,3 @@ DECLARE_string(predict_file); DECLARE_bool(prev_batch_state); DECLARE_string(init_model_path); DECLARE_bool(use_mkldnn); -DECLARE_bool(use_mkldnn_wgt); diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index faeac69513a80b26b0200e1aa516d4cb33459851..ce57a0713092723b6a99b2416e06ff1a436f043b 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -26,3 +26,4 @@ py_test(test_operator SRCS test_operator.py) py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py) py_test(test_sgd_op SRCS test_sgd_op.py) +py_test(test_gradient_checker SRCS test_gradient_checker.py) diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py index 831c0f0f2a9096cd78ae0a0ffd9b515110c6fa06..8b8e2f444be1169c23784321721c5d8154541fcf 100644 --- a/python/paddle/v2/framework/tests/gradient_checker.py +++ b/python/paddle/v2/framework/tests/gradient_checker.py @@ -1,6 +1,7 @@ import unittest import numpy +import itertools import paddle.v2.framework.core as core from paddle.v2.framework.op import Operator @@ -8,6 +9,7 @@ __all__ = ['get_numeric_gradient'] def create_op(op_type): + # TODO need to set attrs kwargs = dict() for in_name in Operator.get_op_input_names(op_type): kwargs[in_name] = in_name @@ -66,7 +68,6 @@ def get_numeric_gradient(op, local_scope.find_var(output).get_tensor().alloc_float(core.CPUPlace( )) - # TODO(yuyang18): Only CPU is support now. cpu_ctx = core.DeviceContext.create(core.CPUPlace()) def get_output(): @@ -109,12 +110,110 @@ def get_numeric_gradient(op, class GradientChecker(unittest.TestCase): - def assert_is_close(self, numeric_grads, scope, max_relative_error, - msg_prefix): - for name in numeric_grads: - b = numpy.array(scope.find_var(grad_var_name(name)).get_tensor()) - a = numeric_grads[name] + def __get_gradient(self, forward_op, backward_op, input_value, grad_names, + place): + """Get the input gradients after running forward and backward operators + on the given places. + + :param forward_op: forward operator + :type forward_op: Operator + :param backward_op: backward operator + :type backward_op: Operator + :param input_value: input values. + :type input_value: dict{string:numpy.array} + :param grad_names: the names of returned input gradients. + :type input_value: a list of string + :param place: the device type. + :type place: CPUPlace or GPUPlace + :return: the input grdients of given grad_names. + :rtype: a list of numpy.array + """ + scope = core.Scope() + ctx = core.DeviceContext.create(place) + + inputs = forward_op.inputs() + in_names = [item for k in inputs for item in inputs[k]] + outputs = forward_op.outputs() + out_names = [item for k in outputs for item in outputs[k]] + + # create input var and set value + for name, value in input_value.iteritems(): + if name not in in_names: + raise ValueError(name + "does not exist in Op's inputs.") + var = scope.new_var(name).get_tensor() + var.set_dims(value.shape) + var.set(value, place) + + # run forward op + for out_name in out_names: + scope.new_var(out_name) + forward_op.infer_shape(scope) + forward_op.run(scope, ctx) + + # set output var's shape + # set output grad to ones + for name in out_names: + out_tensor = scope.find_var(name).get_tensor() + grad_tensor = scope.new_var(grad_var_name(name)).get_tensor() + grad_tensor.set_dims(out_tensor.shape()) + data = numpy.ones(out_tensor.shape(), dtype=numpy.float32) + grad_tensor.set(data, place) + + # run backward op + for name in backward_op.outputs(): + scope.new_var(name) + backward_op.infer_shape(scope) + backward_op.run(scope, ctx) + + outs = [ + numpy.array(scope.find_var(name).get_tensor()) + for name in grad_names + ] + return outs + + def compare_grad(self, forward_op, input_value): + """ Compare the input gradients between CPU and GPU for the given forward + operator. + + :param forward_op: forward operator + :type forward_op: Operator + :param input_value: input values. + :type input_value: dict{string:numpy.array} + :raises: AssertionError, there is different gradient value. + """ + backward_op = core.Operator.backward(forward_op, set()) + # return if not compile with GPU or not implementing GPU kernel + if not (core.is_compile_gpu() and backward_op.support_gpu()): + return + outputs = backward_op.outputs() + out_names = [item for k in outputs for item in outputs[k]] + cpu_grads = self.__get_gradient(forward_op, backward_op, input_value, + out_names, core.CPUPlace()) + gpu_grads = self.__get_gradient(forward_op, backward_op, input_value, + out_names, core.GPUPlace(0)) + + for c_grad, g_grad, name in itertools.izip(cpu_grads, gpu_grads, + out_names): + self.assertTrue( + numpy.allclose( + c_grad, g_grad, atol=1e-4), + "output name: " + name + " has diff") + + def __assert_is_close(self, numeric_grads, analytic_grads, names, + max_relative_error, msg_prefix): + """Use relative error for the comparison. + + :param numeric_grads: the numerical graidents. + :type numeric_grads: a list of numpy.array + :param analytic_grads: the analytical graidents. + :type analytic_grads: a list of numpy.array + :param name: the names of gradients, used to print for debug. + :type names: a list of string + :param msg_prefix: string info, used to print for debug. + :type msf_prefix: string + """ + for a, b, name in itertools.izip(numeric_grads, analytic_grads, names): abs_a = numpy.abs(a) # if abs_a is nearly zero, then use abs error for a, not relative # error. @@ -159,105 +258,26 @@ class GradientChecker(unittest.TestCase): inputs = forward_op.inputs() in_names = [item for k in inputs for item in inputs[k]] - outputs = forward_op.outputs() - out_names = [item for k in outputs for item in outputs[k]] - for no_grad in no_grad_set: if no_grad not in in_names: raise ValueError("no_grad should be in in_names") backward_op = core.Operator.backward(forward_op, no_grad_set) - bwd_outputs = backward_op.outputs() - bwd_out_names = [item for k in bwd_outputs for item in bwd_outputs[k]] - places = [core.CPUPlace()] if not only_cpu and core.is_compile_gpu() and backward_op.support_gpu(): places.append(core.GPUPlace(0)) - numeric_grad = dict() - # get numeric gradient - for check_name in inputs_to_check: - numeric_grad[check_name] = \ - get_numeric_gradient(forward_op, input_vars, output_name, - check_name) + # get numerical gradients + numeric_grads = [ + get_numeric_gradient(forward_op, input_vars, output_name, name) + for name in inputs_to_check + ] - # get operator gradient according to different device + check_names = [grad_var_name(name) for name in inputs_to_check] for place in places: - scope = core.Scope() - ctx = core.DeviceContext.create(place) - - # create input var and set value - for name, value in input_vars.iteritems(): - if name not in in_names: - raise ValueError(name + " not in op.inputs_") - var = scope.new_var(name).get_tensor() - var.set_dims(value.shape) - var.set(value, place) - - # create output var - for out_name in out_names: - scope.new_var(out_name).get_tensor() - - # infer the shape of output var and compute/set value of output var - forward_op.infer_shape(scope) - forward_op.run(scope, ctx) - - # create output grad var - # set shape as the output var - # set value of this grad to ones - for name in out_names: - out_tensor = scope.find_var(name).get_tensor() - grad_tensor = scope.new_var(grad_var_name(name)).get_tensor() - grad_tensor.set_dims(out_tensor.shape()) - data = 1.0 * numpy.ones(out_tensor.shape()) - grad_tensor.set(data, place) - - # create input grad var - for name in bwd_out_names: - scope.new_var(name).get_tensor() - - # infer the shape of input gradient var and compute/set it's value - # with backward op - backward_op.infer_shape(scope) - backward_op.run(scope, ctx) - - self.assert_is_close(numeric_grad, scope, max_relative_error, - "Gradient Check On %s" % str(place)) - - -if __name__ == '__main__': - - class GetNumericGradientTest(unittest.TestCase): - def test_add_op(self): - add_op = Operator('add_two', X="X", Y="Y", Out="Z") - x = numpy.random.random((10, 1)).astype("float32") - y = numpy.random.random((10, 1)).astype("float32") - - arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X') - self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-2) - - def test_softmax_op(self): - def stable_softmax(x): - """Compute the softmax of vector x in a numerically stable way.""" - shiftx = x - numpy.max(x) - exps = numpy.exp(shiftx) - return exps / numpy.sum(exps) - - def label_softmax_grad(Y, dY): - dX = Y * 0.0 - for i in range(Y.shape[0]): - d = numpy.dot(Y[i, :], dY[i, :]) - dX[i, :] = Y[i, :] * (dY[i, :] - d) - return dX - - softmax_op = Operator("softmax", X="X", Y="Y") - - X = numpy.random.random((2, 2)).astype("float32") - Y = numpy.apply_along_axis(stable_softmax, 1, X) - dY = numpy.ones(Y.shape) - dX = label_softmax_grad(Y, dY) - - arr = get_numeric_gradient(softmax_op, {"X": X}, 'Y', 'X') - numpy.testing.assert_almost_equal(arr, dX, decimal=1e-2) - - unittest.main() + # get analytical gradients according to different device + analytic_grads = self.__get_gradient(forward_op, backward_op, + input_vars, check_names, place) + self.__assert_is_close(numeric_grads, analytic_grads, check_names, + max_relative_error, + "Gradient Check On %s" % str(place)) diff --git a/python/paddle/v2/framework/tests/test_gradient_checker.py b/python/paddle/v2/framework/tests/test_gradient_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..e0b315120862bea284e067070492dcdfbb661081 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_gradient_checker.py @@ -0,0 +1,43 @@ +import unittest +import numpy +from paddle.v2.framework.op import Operator +from gradient_checker import GradientChecker +from gradient_checker import get_numeric_gradient + + +class GetNumericGradientTest(unittest.TestCase): + def test_add_op(self): + add_op = Operator('add_two', X="X", Y="Y", Out="Z") + x = numpy.random.random((10, 1)).astype("float32") + y = numpy.random.random((10, 1)).astype("float32") + + arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X') + self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-4) + + def test_softmax_op(self): + def stable_softmax(x): + """Compute the softmax of vector x in a numerically stable way.""" + shiftx = x - numpy.max(x) + exps = numpy.exp(shiftx) + return exps / numpy.sum(exps) + + def label_softmax_grad(Y, dY): + dX = Y * 0.0 + for i in range(Y.shape[0]): + d = numpy.dot(Y[i, :], dY[i, :]) + dX[i, :] = Y[i, :] * (dY[i, :] - d) + return dX + + softmax_op = Operator("softmax", X="X", Y="Y") + + X = numpy.random.random((2, 2)).astype("float32") + Y = numpy.apply_along_axis(stable_softmax, 1, X) + dY = numpy.ones(Y.shape) + dX = label_softmax_grad(Y, dY) + + arr = get_numeric_gradient(softmax_op, {"X": X}, 'Y', 'X') + numpy.testing.assert_almost_equal(arr, dX, decimal=1e-2) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_mean_op.py b/python/paddle/v2/framework/tests/test_mean_op.py index b5d52b90567bcd0c9f376147145d8638049f7bab..f32b3160d651a290823223c46c45bb3b6950a505 100644 --- a/python/paddle/v2/framework/tests/test_mean_op.py +++ b/python/paddle/v2/framework/tests/test_mean_op.py @@ -1,5 +1,6 @@ import unittest from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op import numpy as np @@ -12,5 +13,12 @@ class TestMeanOp(unittest.TestCase): self.outputs = {'Out': np.mean(self.inputs['X'])} +class MeanGradOpTest(GradientChecker): + def test_normal(self): + op = create_op("mean") + inputs = {"X": np.random.random((10, 10)).astype("float32")} + self.check_grad(op, inputs, set("X"), "Out") + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py index 2a57a41ed8b718fd420062ba68e853a4861b7359..273c2e5ab1a84d12621fe9568c4cf22073b6aed4 100644 --- a/python/paddle/v2/framework/tests/test_sigmoid_op.py +++ b/python/paddle/v2/framework/tests/test_sigmoid_op.py @@ -1,6 +1,7 @@ import unittest -from op_test_util import OpTestMeta import numpy as np +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op class TestSigmoidOp(unittest.TestCase): @@ -8,12 +9,20 @@ class TestSigmoidOp(unittest.TestCase): def setUp(self): self.type = "sigmoid" - self.inputs = {'X': np.random.random((32, 100)).astype("float32")} + self.inputs = {'X': np.random.random((15, 31)).astype("float32")} self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))} -#class TestSigmoidGradOp(unittest.TestCase): -#TODO(qingqing) add unit test +class TestSigmoidGradOp(GradientChecker): + def test_grad(self): + op = create_op("sigmoid") + inputs = {"X": np.random.uniform(0.1, 1, [11, 17]).astype("float32")} + # compare gpu and cpu results for backward op. + # this test will be skiped if only compiling CPU version. + self.compare_grad(op, inputs) + # check gradients + self.check_grad(op, inputs, set("X"), "Y", max_relative_error=0.007) + if __name__ == '__main__': unittest.main()