提交 c3864eab 编写于 作者: C chengduoZH

If axis == 0, directly copy D->D (device-to-device) instead of calling the concat functor.

上级 131ec276
...@@ -33,14 +33,26 @@ class ConcatKernel : public framework::OpKernel<T> { ...@@ -33,14 +33,26 @@ class ConcatKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
out->mutable_data<T>(place); out->mutable_data<T>(place);
// TODO(zcd): Sometimes direct copies will be faster // Sometimes direct copies will be faster, this maybe need deeply analysis.
std::vector<framework::Tensor> inputs(ins.size()); if (axis == 0 && ins.size() < 10) {
for (size_t j = 0; j < ins.size(); ++j) { size_t output_offset = 0;
inputs[j] = *ins[j]; for (auto* in : ins) {
auto in_stride = framework::stride_numel(in->dims());
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
out->data<T>() + output_offset, out_stride,
in->data<T>(), in_stride, in_stride[axis]);
output_offset += in_stride[axis];
}
} else {
std::vector<framework::Tensor> inputs(ins.size());
for (size_t j = 0; j < ins.size(); ++j) {
inputs[j] = *ins[j];
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);
} }
auto& dev_ctx = ctx.template device_context<DeviceContext>();
paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);
} }
}; };
...@@ -52,17 +64,31 @@ class ConcatGradKernel : public framework::OpKernel<T> { ...@@ -52,17 +64,31 @@ class ConcatGradKernel : public framework::OpKernel<T> {
auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X")); auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis")); int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
// TODO(zcd): Sometimes direct copies will be faster // Sometimes direct copies will be faster, this maybe need deeply analysis.
std::vector<framework::Tensor> outputs(outs.size()); if (axis == 0 && outs.size() < 10) {
for (size_t j = 0; j < outs.size(); ++j) { size_t input_offset = 0;
outs[j]->mutable_data<T>(ctx.GetPlace()); auto in_stride = framework::stride_numel(in->dims());
outputs[j] = *outs[j];
} for (auto& out : outs) {
out->mutable_data<T>(ctx.GetPlace());
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
out_stride, in->data<T>() + input_offset,
in_stride, out_stride[axis]);
input_offset += out_stride[axis];
}
} else {
std::vector<framework::Tensor> outputs(outs.size());
for (size_t j = 0; j < outs.size(); ++j) {
outs[j]->mutable_data<T>(ctx.GetPlace());
outputs[j] = *outs[j];
}
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
paddle::operators::math::ConcatGradFunctor<DeviceContext, T> paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
concat_grad_functor; concat_grad_functor;
concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs); concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs);
}
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册