提交 bcc0dad7 编写于 作者: D dangqingqing

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into lstm_bp

......@@ -64,12 +64,18 @@ paddle_error paddle_gradient_machine_create_for_inference_with_parameters(
modelConfigProtobuf.resize(modelConfigSize);
is.read(&modelConfigProtobuf[0], modelConfigSize);
paddle::TrainerConfig config;
paddle::ModelConfig modelConfig;
if (!config.ParseFromString(modelConfigProtobuf) || !config.IsInitialized()) {
return kPD_PROTOBUF_ERROR;
if (!modelConfig.ParseFromString(modelConfigProtobuf) ||
!modelConfig.IsInitialized()) {
return kPD_PROTOBUF_ERROR;
}
} else {
modelConfig = config.model_config();
}
auto ptr = new paddle::capi::CGradientMachine();
ptr->machine.reset(paddle::GradientMachine::create(
config.model_config(), CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
modelConfig, CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
std::vector<paddle::ParameterPtr>& parameters = ptr->machine->getParameters();
for (auto& para : parameters) {
para->load(is);
......
......@@ -162,6 +162,8 @@ or not. But the output only shares the LoD with input `X`.
namespace ops = paddle::operators;
REGISTER_OP(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
cross_entropy_grad, ops::CrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>);
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>,
ops::CrossEntropyOpKernel<double>);
REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpKernel<float>);
ops::CrossEntropyGradientOpKernel<float>,
ops::CrossEntropyGradientOpKernel<double>);
......@@ -108,6 +108,8 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>,
ops::CrossEntropyOpCUDAKernel<double>);
REGISTER_OP_GPU_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpCUDAKernel<float>);
ops::CrossEntropyGradientOpCUDAKernel<float>,
ops::CrossEntropyGradientOpCUDAKernel<double>);
......@@ -54,6 +54,7 @@ class CrossEntropyFunctor<platform::CPUPlace, T> {
};
template class CrossEntropyFunctor<platform::CPUPlace, float>;
template class CrossEntropyFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -39,11 +39,36 @@ __device__ __forceinline__ T sum_single_warp(T val) {
return val;
}
// CUDA do not support dynamic arrary in template
// https://stackoverflow.com/questions/20497209
template <typename T>
struct SharedMemory {
// Ensure that we won't compile any un-specialized types
__device__ T* GetPointer() { return NULL; }
};
template <>
struct SharedMemory<float> {
__device__ float* GetPointer() {
extern __shared__ float s_float[];
return s_float;
}
};
template <>
struct SharedMemory<double> {
__device__ double* GetPointer() {
extern __shared__ double s_double[];
return s_double;
}
};
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
const int class_num) {
int tid = threadIdx.x;
extern __shared__ T d_sum[];
SharedMemory<T> d_sum_shared;
T* d_sum = d_sum_shared.GetPointer();
d_sum[tid] = 0;
int cur_idx = tid;
......@@ -102,6 +127,7 @@ class CrossEntropyFunctor<platform::GPUPlace, T> {
};
template class CrossEntropyFunctor<platform::GPUPlace, float>;
template class CrossEntropyFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/utils/PythonUtil.h"
DEFINE_string(model_dir, "", "Directory for separated model files");
DEFINE_string(config_file, "", "Config file for the model");
DEFINE_string(model_file, "", "File for merged model file");
using namespace paddle; // NOLINT
......@@ -28,7 +29,8 @@ using namespace std; // NOLINT
int main(int argc, char** argv) {
initMain(argc, argv);
initPython(argc, argv);
string confFile = TrainerConfigHelper::getConfigNameFromPath(FLAGS_model_dir);
string confFile = FLAGS_config_file;
#ifndef PADDLE_WITH_CUDA
FLAGS_use_gpu = false;
#endif
......
......@@ -19,7 +19,7 @@ import "ModelConfig.proto";
package paddle;
message OptimizationConfig {
required int32 batch_size = 3;
optional int32 batch_size = 3 [ default = 1 ];
required string algorithm = 4 [ default = "async_sgd" ];
optional int32 num_batches_per_send_parameter = 5 [ default = 1 ];
optional int32 num_batches_per_get_parameter = 6 [ default = 1 ];
......
......@@ -8,6 +8,15 @@ from paddle.v2.framework.executor import Executor
from paddle.v2.framework.framework import Program, OpProtoHolder
def randomize_probability(batch_size, class_num, dtype='float32'):
prob = np.random.uniform(
0.1, 1.0, size=(batch_size, class_num)).astype(dtype)
prob_sum = prob.sum(axis=1)
for i in xrange(len(prob)):
prob[i] /= prob_sum[i]
return prob
def grad_var_name(var_name):
return var_name + "@GRAD"
......@@ -233,7 +242,7 @@ def append_input_output(block, op_proto, np_list, is_input):
if (var_name not in np_list) and var_proto.dispensable:
continue
assert (var_name in np_list) or (var_proto.dispensable), \
"Missing {} as input".format(var_name)
"Missing {} as input".format(var_name)
if var_proto.duplicable:
assert isinstance(np_list[var_name], list), \
"Duplicable {} should be set as list".format(var_name)
......@@ -379,9 +388,9 @@ class OpTest(unittest.TestCase):
def err_msg():
offset = np.argmax(diff_mat > max_relative_error)
return ("%s Variable %s max gradient diff %f over limit %f, "
"the first error element is %d") % (
"the first error element is %d, %f, %f") % (
msg_prefix, name, max_diff, max_relative_error,
offset)
offset, a.flatten()[offset], b.flatten()[offset])
self.assertLessEqual(max_diff, max_relative_error, err_msg())
......@@ -389,6 +398,7 @@ class OpTest(unittest.TestCase):
inputs_to_check,
output_names,
no_grad_set=None,
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None):
......@@ -411,6 +421,7 @@ class OpTest(unittest.TestCase):
self.inputs,
input_to_check,
output_names,
delta=numeric_grad_delta,
in_place=in_place) for input_to_check in inputs_to_check
]
grad_names = [
......
import unittest
import numpy as np
from op_test import OpTest
from op_test import OpTest, randomize_probability
class TestCrossEntropyOp1(OpTest):
......@@ -12,12 +12,12 @@ class TestCrossEntropyOp1(OpTest):
batch_size = 30
class_num = 10
X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
X = randomize_probability(batch_size, class_num, dtype='float64')
label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32")
cross_entropy = np.asmatrix(
[[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])],
dtype="float32")
dtype="float64")
self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy}
......@@ -27,7 +27,7 @@ class TestCrossEntropyOp1(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Y")
self.check_grad(["X"], "Y", numeric_grad_delta=0.001)
class TestCrossEntropyOp2(OpTest):
......@@ -39,8 +39,7 @@ class TestCrossEntropyOp2(OpTest):
batch_size = 5
class_num = 37
X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
X = randomize_probability(batch_size, class_num)
label = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
label /= label.sum(axis=1, keepdims=True)
......@@ -55,7 +54,8 @@ class TestCrossEntropyOp2(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Y", max_relative_error=0.05)
self.check_grad(
["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001)
class TestCrossEntropyOp3(OpTest):
......@@ -67,8 +67,7 @@ class TestCrossEntropyOp3(OpTest):
batch_size = 5
class_num = 17
X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
X = randomize_probability(batch_size, class_num)
label_index = np.random.randint(
0, class_num, (batch_size), dtype="int32")
label = np.zeros(X.shape)
......@@ -88,7 +87,8 @@ class TestCrossEntropyOp3(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Y", max_relative_error=0.05)
self.check_grad(
["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册