Commit dca9941e authored by typhoonzero

pass size when copy

Parent 67d6f3a8
@@ -38,7 +38,7 @@ class ConcatKernel : public framework::OpKernel<T> {
       auto in_stride = framework::stride_numel(in->dims());
       StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
                                   out->data<T>() + output_offset, out_stride,
-                                  in->data<T>(), in_stride);
+                                  in->data<T>(), in_stride, in_stride[axis]);
       output_offset += in_stride[axis];
     }
   }
@@ -59,7 +59,7 @@ class ConcatGradKernel : public framework::OpKernel<T> {
       auto out_stride = framework::stride_numel(out->dims());
       StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
                                   out_stride, in->data<T>() + input_offset,
-                                  in_stride);
+                                  in_stride, out_stride[axis]);
       input_offset += out_stride[axis];
     }
   }
...
@@ -38,7 +38,7 @@ class SplitOpKernel : public framework::OpKernel<T> {
       auto out_stride = framework::stride_numel(out->dims());
       StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
                                   out_stride, in->data<T>() + input_offset,
-                                  in_stride);
+                                  in_stride, out_stride[axis]);
       input_offset += out_stride[axis];
     }
   }
...
@@ -54,11 +54,11 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
                                      int64_t axis, T* dst,
                                      const framework::DDim& dst_stride_numel,
                                      const T* src,
-                                     const framework::DDim& src_stride_numel) {
+                                     const framework::DDim& src_stride_numel,
+                                     int64_t size) {
   int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
   int64_t src_after = src_stride_numel[axis];
   int64_t dst_after = dst_stride_numel[axis];
-  int64_t copy_size = std::min(src_after, dst_after);
   auto place = ctx.GetPlace();
 
   PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(),
@@ -83,15 +83,14 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
     if (platform::is_cpu_place(place)) {
       auto& cpu_place = boost::get<platform::CPUPlace>(place);
       memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
-                   src + i * src_after, sizeof(T) * copy_size);
+                   src + i * src_after, sizeof(T) * size);
     } else {
 #ifdef PADDLE_WITH_CUDA
       auto& gpu_place = boost::get<platform::CUDAPlace>(place);
       auto& cuda_ctx =
           reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
       memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
-                   src + i * src_after, sizeof(T) * copy_size,
-                   cuda_ctx.stream());
+                   src + i * src_after, sizeof(T) * size, cuda_ctx.stream());
 #else
       PADDLE_THROW("Paddle is not compiled with GPU");
 #endif
...
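With the extra `size` argument, `StridedNumelCopyWithAxis` no longer derives the per-row element count from `std::min(src_after, dst_after)`; each caller states it explicitly (`in_stride[axis]` in `ConcatKernel`, `out_stride[axis]` in `ConcatGradKernel` and `SplitOpKernel`). Below is a minimal CPU-only sketch of that contract, not Paddle's implementation: `StrideNumel` and plain `std::memcpy` stand in for `framework::stride_numel` and `memory::Copy`, and `main` concatenates two 2x2 matrices along axis 1 the same way `ConcatKernel` does.

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Analogue of framework::stride_numel: stride[i] = product of dims[i..n-1].
std::vector<int64_t> StrideNumel(const std::vector<int64_t>& dims) {
  std::vector<int64_t> stride(dims.size());
  int64_t acc = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    acc *= dims[i];
    stride[i] = acc;
  }
  return stride;
}

// Copy `size` elements per outer row from src to dst along `axis`; the caller
// decides `size` instead of the helper taking min(src_after, dst_after).
template <typename T>
void StridedCopyWithAxis(int64_t axis, T* dst,
                         const std::vector<int64_t>& dst_stride_numel,
                         const T* src,
                         const std::vector<int64_t>& src_stride_numel,
                         int64_t size) {
  int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
  int64_t src_after = src_stride_numel[axis];
  int64_t dst_after = dst_stride_numel[axis];
  for (int64_t i = 0; i < before; ++i) {
    std::memcpy(dst + i * dst_after, src + i * src_after, sizeof(T) * size);
  }
}

int main() {
  // Concat two 2x2 matrices along axis 1 into a 2x4 output, the same pattern
  // ConcatKernel uses: pass in_stride[axis] as the copy size and advance the
  // output offset by it after each input.
  std::vector<float> a = {1, 2, 3, 4};
  std::vector<float> b = {5, 6, 7, 8};
  std::vector<float> out(8, 0.f);
  auto in_stride = StrideNumel({2, 2});   // {4, 2}
  auto out_stride = StrideNumel({2, 4});  // {8, 4}
  const std::vector<float>* inputs[] = {&a, &b};
  int64_t output_offset = 0;
  for (const auto* in : inputs) {
    StridedCopyWithAxis<float>(1, out.data() + output_offset, out_stride,
                               in->data(), in_stride, in_stride[1]);
    output_offset += in_stride[1];
  }
  for (float v : out) std::cout << v << " ";  // prints: 1 2 5 6 3 4 7 8
  std::cout << "\n";
}
```

The upshot of the signature change is that the copy length becomes the caller's decision rather than whatever `min(src_after, dst_after)` happens to be, which is exactly what the hunks above wire through.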
@@ -121,6 +121,7 @@ def split_dense_variable(var_list,
             block_size += dim1 - remains
         # update split_count after aligning
         split_count = int(math.ceil(var_numel / float(block_size)))
+        print("###split var ", var.name, var.shape, block_size, split_count)
         for block_id in xrange(split_count):
             curr_block_size = min(block_size, var_numel - (
                 (block_id) * block_size))
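The print added here exposes the block-splitting arithmetic in `split_dense_variable`: `split_count = ceil(var_numel / block_size)` blocks of at most `block_size` elements, with the last block taking the remainder. A rough standalone sketch of just that arithmetic (hypothetical `split_blocks` helper, Python 3 `range` instead of `xrange`):

```python
import math

def split_blocks(var_numel, block_size):
    # Mirror of the loop above: ceil-divide var_numel elements into blocks of
    # at most block_size; the final block absorbs the remainder.
    split_count = int(math.ceil(var_numel / float(block_size)))
    return [min(block_size, var_numel - block_id * block_size)
            for block_id in range(split_count)]

print(split_blocks(10, 4))  # [4, 4, 2]
```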
@@ -255,6 +256,7 @@ class DistributeTranspiler:
                 splited_shape = [rows]
                 if len(orig_shape) >= 2:
                     splited_shape.extend(orig_shape[1:])
+                print("###splited: ", size, rows, splited_shape)
                 var = program.global_block().create_var(
                     name="%s.block%d" % (varname, i),
                     psersistable=False,
@@ -262,6 +264,7 @@ class DistributeTranspiler:
                     type=orig_var.type,
                     shape=splited_shape)  # flattend splited var
                 var_mapping[varname].append(var)
+                print("###created split var ", var)
         return var_mapping
 
     def _clone_var(self, block, var):
@@ -528,6 +531,8 @@ class DistributeTranspiler:
         """
         # step5
         pserver_program = Program()
+        print("param mapping on pserver: #### ",
+              self.param_grad_ep_mapping[endpoint]["params"])
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
             self._clone_var(pserver_program.global_block(), v)
         for v in self.param_grad_ep_mapping[endpoint]["grads"]:
...