Commit 0190d5d6, authored by dangqingqing

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into check_out_lod

@@ -2,7 +2,7 @@ INCLUDE(ExternalProject)
 SET(NCCL_SOURCE_DIR ${THIRD_PARTY_PATH}/nccl)
-INCLUDE_DIRECTORIES(${NCCL_SOURCE_DIR}/src/extern_nccl)
+INCLUDE_DIRECTORIES(${NCCL_SOURCE_DIR}/src/extern_nccl/src)
 if(WITH_DSO)
...
@@ -120,7 +120,7 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
       math::matmul<Place, T>(context.device_context(), filter, true,
                              input_batch, false, T(1.0), &col_matrix, T(0.0));
       col2im(context.device_context(), output_batch, col, strides[0],
-             strides[1], 0, 0);
+             strides[1], 0, 0, 0, 0);
     }
   }
 };
@@ -206,7 +206,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
       // im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w)
       im2col(context.device_context(), output_grad_batch, col, strides[0],
-             strides[1], paddings[0], paddings[1]);
+             strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
       // gemm: dx = filter * dy
       //       (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h)
@@ -238,7 +238,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
       // im2col: (c * h * w, k_h * k_w)
       im2col(context.device_context(), output_grad_batch, col, strides[0],
-             strides[1], paddings[0], paddings[1]);
+             strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
       // gemm: d_filter = x * y_grad^T
       //       (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h)
...
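The three hunks above widen the col2im/im2col calls from one padding argument per spatial dimension to explicit before/after paddings: the forward kernel passes zeros four times, and the gradient kernels duplicate each symmetric value (paddings[0], paddings[0], paddings[1], paddings[1]). A minimal Python sketch, not code from this PR and using a hypothetical helper name, of why duplicating a symmetric padding value leaves the column-buffer geometry unchanged:

def im2col_output_size(in_size, kernel, stride, pad_before, pad_after):
    # Sliding-window count along one spatial dimension with independent
    # paddings on each side (hypothetical helper, for illustration only).
    return (in_size + pad_before + pad_after - kernel) // stride + 1

# Duplicating a symmetric padding value, as the gradient kernels do with
# paddings[0] and paddings[1], reproduces the old two-argument behaviour:
old_style = (32 + 2 * 1 - 3) // 1 + 1            # single padding applied to both sides
new_style = im2col_output_size(32, 3, 1, 1, 1)   # explicit before/after paddings
assert old_style == new_style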
@@ -326,6 +326,17 @@ class Parameters(object):
         self.set(name, arr.reshape(self.get_shape(name)))
 
     def to_tar(self, f):
+        """
+        Save parameters to a tar file.
+
+        WARNING: You should use `paddle.v2.trainer.SGD.save_parameter_to_tar(f)`
+        to save parameters most of the time. Otherwise, some settings such
+        as model average will not take effect.
+
+        :param f:
+        :type f: file
+        :return:
+        """
         tar = tarfile.TarFile(fileobj=f, mode='w')
         for nm in self.names():
             buf = cStringIO.StringIO()
...
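The new docstring points callers at the trainer-level save path. A short usage sketch, assuming a configured paddle.v2.trainer.SGD instance named trainer (the file name is illustrative):

import paddle.v2 as paddle

# Assume `trainer` was built earlier, e.g.:
#   trainer = paddle.trainer.SGD(cost=cost, parameters=parameters,
#                                update_equation=optimizer)

# Preferred: save through the trainer so settings such as model average
# are applied before the parameters are serialized.
with open('params.tar', 'w') as f:
    trainer.save_parameter_to_tar(f)

# Calling parameters.to_tar(f) directly also writes a tar archive, but it
# bypasses those trainer-side settings, which is what the WARNING above
# cautions against.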