Commit d7db15f3 authored by Yancey, committed by GitHub

Use StridedMemCpy in Concat/Split Kernel (#4188)

Use StridedMemCpy in Concat/Split Op
Parent 5deeefed
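A note on the idea before the hunks: the old Concat/Split kernels walked tensors with hand-computed `before`/`after` products and raw `memcpy` on CPU pointers; this commit replaces that with `StridedMemcpy`, which copies a dims-shaped block through source and destination strides and dispatches on the device context, so the same kernel works on CPU and GPU. Below is a minimal CPU-only sketch of what such a strided copy does — an illustration, not Paddle's actual `StridedMemcpy`:

#include <cstddef>
#include <cstring>

// Copy a `dims`-shaped block of T elements, reading through `src_stride`
// and writing through `dst_stride` (strides in elements, innermost last).
// Assumes the innermost dimension is contiguous (stride 1) on both sides,
// which holds for the row-major tensors used below.
template <typename T>
void StridedCopySketch(const T* src, const size_t* src_stride,
                       const size_t* dims, size_t rank,
                       const size_t* dst_stride, T* dst) {
  if (rank == 1) {
    std::memcpy(dst, src, dims[0] * sizeof(T));
    return;
  }
  for (size_t i = 0; i < dims[0]; ++i) {
    StridedCopySketch(src + i * src_stride[0], src_stride + 1, dims + 1,
                      rank - 1, dst_stride + 1, dst + i * dst_stride[0]);
  }
}

Concat then reduces to one such copy per input, aimed at the right element offset of the output; split is the same copy with source and destination swapped.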
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <algorithm>
+#include <atomic>
 #include <string>
 #include <unordered_map>
 #include <vector>
...
@@ -80,6 +80,15 @@ void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
   platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
 }

+template <>
+void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
+                                                  void* dst,
+                                                  platform::GPUPlace src_place,
+                                                  const void* src, size_t num) {
+  platform::SetDeviceId(dst_place.device);
+  platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
+}
+
 #endif  // PADDLE_ONLY_CPU

 }  // namespace memory
...
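This specialization fills in the GPU-to-GPU cell of `memory::Copy`'s (dst place, src place) dispatch table, which `StridedMemcpy` needs for device-resident tensors. Here is a self-contained sketch of that dispatch pattern, with stand-in place structs and the CUDA call replaced by a log line (the names are illustrative, not Paddle's):

#include <cstddef>
#include <cstdio>
#include <cstring>

struct CPUPlace {};
struct GPUPlace { int device; };

// Primary template: one explicit specialization per (DstPlace, SrcPlace) pair.
template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace dst_place, void* dst, SrcPlace src_place, const void* src,
          size_t num);

template <>
void Copy<CPUPlace, CPUPlace>(CPUPlace, void* dst, CPUPlace, const void* src,
                              size_t num) {
  std::memcpy(dst, src, num);
}

template <>
void Copy<GPUPlace, GPUPlace>(GPUPlace dst_place, void* dst, GPUPlace,
                              const void* src, size_t num) {
  // The real specialization selects the destination device and calls
  // GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice).
  std::printf("D2D copy of %zu bytes on GPU %d\n", num, dst_place.device);
}

Keying the overload set on the place pair keeps every transfer direction (H2H, H2D, D2H, and now D2D) behind one call site, so kernels never branch on device type themselves.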
@@ -25,12 +25,14 @@ class ConcatOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
+                      "Inputs(X) of ConcatOp should not be empty.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of ConcatOp should not be null.");
     auto ins = ctx->GetInputsDim("X");
     size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
-    size_t n = ins.size();
+    const size_t n = ins.size();
     PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
@@ -72,10 +74,27 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };

+class ConcatOpGrad : public framework::OperatorWithKernel {
+ public:
+  ConcatOpGrad(const std::string &type,
+               const framework::VariableNameMap &inputs,
+               const framework::VariableNameMap &outputs,
+               const framework::AttributeMap &attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+ protected:
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(concat, ops::ConcatOp, ops::ConcatOpMaker)
+REGISTER_OP(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad,
+            ops::ConcatOpGrad)
 REGISTER_OP_CPU_KERNEL(concat,
                        ops::ConcatKernel<paddle::platform::CPUPlace, float>)
+REGISTER_OP_CPU_KERNEL(concat_grad,
+                       ops::ConcatGradKernel<paddle::platform::CPUPlace, float>)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/concat_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(concat,
                       ops::ConcatKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    concat_grad, ops::ConcatGradKernel<paddle::platform::GPUPlace, float>);
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/strided_memcpy.h"

 namespace paddle {
 namespace operators {
@@ -27,35 +28,39 @@ class ConcatKernel : public framework::OpKernel {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
     auto* out = ctx.Output<framework::Tensor>("Out");
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
-    size_t n = ins.size();
-    size_t output_axis_dim = 0;
-    size_t before = 1, after = 1;
-    for (size_t i = 0; i < n; i++) {
-      output_axis_dim += ins[i]->dims()[axis];
-    }
-    auto& input_zero = ins[0];
-    for (int64_t i = 0; i < input_zero->dims().size(); i++) {
-      if (i == axis) {
-        continue;
-      }
-      if (i < axis) {
-        before *= input_zero->dims()[i];
-      } else {
-        after *= input_zero->dims()[i];
-      }
-    }
+    const size_t n = ins.size();
     size_t output_offset = 0;
+    out->mutable_data<T>(ctx.GetPlace());
+    auto out_stride = framework::stride(out->dims());
     for (size_t i = 0; i < n; i++) {
       auto& in = ins[i];
       auto axis_dim = in->dims()[axis];
-      for (size_t j = 0; j < before; j++) {
-        size_t len = axis_dim * after * sizeof(T);
-        const T* src = in->data<T>() + axis_dim * after * j;
-        T* out_data = out->mutable_data<T>(platform::CPUPlace());
-        T* dest = out_data + output_offset + output_axis_dim * after * j;
-        memcpy(dest, src, len);
-      }
-      output_offset += axis_dim * after;
+      auto in_stride = framework::stride(in->dims());
+      StridedMemcpy<T>(ctx.device_context(), in->data<T>(), in_stride,
+                       in->dims(), out_stride, out->data<T>() + output_offset);
+      output_offset += axis_dim * in_stride[axis];
     }
   }
 };

+template <typename Place, typename T>
+class ConcatGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto outs =
+        ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
+    int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
+    const size_t n = outs.size();
+    size_t input_offset = 0;
+    auto in_stride = framework::stride(in->dims());
+    for (size_t i = 0; i < n; i++) {
+      auto& out = outs[i];
+      out->mutable_data<T>(ctx.GetPlace());
+      size_t axis_dim = out->dims()[axis];
+      auto out_stride = framework::stride(out->dims());
+      StridedMemcpy<T>(ctx.device_context(), in->data<T>() + input_offset,
+                       in_stride, out->dims(), out_stride, out->data<T>());
+      input_offset += axis_dim * in_stride[axis];
+    }
+  }
+};
...
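As a sanity check on `output_offset += axis_dim * in_stride[axis]`: for row-major tensors, `in_stride[axis]` is the product of the dimensions after `axis`, which every input shares, so it equals `out_stride[axis]` as well. The following standalone check runs the offset arithmetic on the shapes from the updated Python test (the `stride` helper mirrors the row-major layout that `framework::stride` is used with here, an assumption on my part):

#include <cassert>
#include <cstddef>
#include <vector>

// Row-major strides: stride[i] = product of dims after i.
std::vector<size_t> stride(const std::vector<size_t>& dims) {
  std::vector<size_t> s(dims.size(), 1);
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i)
    s[i] = s[i + 1] * dims[i + 1];
  return s;
}

int main() {
  // (2,1,4,5), (2,2,4,5), (2,3,4,5) concatenated on axis 1 -> (2,6,4,5).
  const size_t axis = 1;
  const std::vector<std::vector<size_t>> ins = {
      {2, 1, 4, 5}, {2, 2, 4, 5}, {2, 3, 4, 5}};
  const std::vector<size_t> expected = {0, 20, 60};  // element offsets in Out
  size_t output_offset = 0;
  for (size_t i = 0; i < ins.size(); ++i) {
    assert(output_offset == expected[i]);
    output_offset += ins[i][axis] * stride(ins[i])[axis];
  }
  assert(output_offset == 120);  // 6 * 20: the full extent of Out on axis 1
  return 0;
}

Each input block therefore lands exactly where numpy's `concatenate` would place it; the dimensions before `axis` are handled by the strided copy's outer loops through `out_stride`.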
@@ -25,6 +25,10 @@ class SplitOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SplitOp should not be null.");
+    PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
+                      "Outputs(Out) of SplitOp should not be empty.");
     auto in_dims = ctx->GetInputDim("X");
     auto outs_names = ctx->Outputs("Out");
     size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
@@ -55,9 +59,6 @@ class SplitOp : public framework::OperatorWithKernel {
         dim[axis] = sections[i];
         outs_dims.push_back(dim);
       }
-    } else {
-      PADDLE_ENFORCE_NOT_NULL(nullptr, "split operator should",
-                              " specify indices or sections.");
     }
     ctx->SetOutputsDim("Out", outs_dims);
   }
@@ -117,4 +118,4 @@ USE_CPU_ONLY_OP(concat);
 REGISTER_OP(split, ops::SplitOp, ops::SplitOpMaker, split_grad,
             ops::SplitOpGrad);
 REGISTER_OP_CPU_KERNEL(split,
-                       ops::SplitKernel<paddle::platform::CPUPlace, float>);
+                       ops::SplitOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/split_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(split,
                       ops::SplitOpKernel<paddle::platform::GPUPlace, float>);
@@ -16,44 +16,29 @@ limitations under the License. */
 #include <vector>
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/strided_memcpy.h"

 namespace paddle {
 namespace operators {

 template <typename Place, typename T>
-class SplitKernel : public framework::OpKernel {
+class SplitOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* in = ctx.Input<framework::Tensor>("X");
     auto outs = ctx.MultiOutput<framework::Tensor>("Out");
+    auto in_stride = framework::stride(in->dims());
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
-    size_t before = 1, after = 1;
     const size_t n = outs.size();
-    size_t input_axis_dim = in->dims()[axis];
-    for (int64_t i = 0; i < in->dims().size(); ++i) {
-      if (i == axis) {
-        continue;
-      }
-      if (i < axis) {
-        before *= in->dims()[i];
-      } else {
-        after *= in->dims()[i];
-      }
-    }
     size_t input_offset = 0;
     for (size_t i = 0; i < n; i++) {
       auto& out = outs[i];
+      out->mutable_data<T>(ctx.GetPlace());
       size_t axis_dim = out->dims()[axis];
-      for (size_t j = 0; j < before; j++) {
-        size_t len = axis_dim * after * sizeof(T);
-        T* dest =
-            out->mutable_data<T>(platform::CPUPlace()) + axis_dim * after * j;
-        const T* src =
-            in->data<T>() + input_offset + input_axis_dim * after * j;
-        memcpy(dest, src, len);
-      }
-      input_offset += axis_dim * after;
+      auto out_stride = framework::stride(out->dims());
+      StridedMemcpy<T>(ctx.device_context(), in->data<T>() + input_offset,
+                       in_stride, out->dims(), out_stride, out->data<T>());
+      input_offset += axis_dim * in_stride[axis];
     }
   }
 };
...
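SplitOpKernel is the mirror image of ConcatKernel: `input_offset` advances by `axis_dim * in_stride[axis]` through the single source tensor while each output is written from offset zero through its own strides. With the shapes in the updated split test below, a (4, 2, 5) input split along axis 0 into sections [1, 2, 1] (numpy split indices [1, 3]), `in_stride` is (10, 5, 1), so the reads start at element offsets 0, 10, and 30.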
@@ -6,10 +6,10 @@ from op_test import OpTest
 class TestConcatOp(OpTest):
     def setUp(self):
         self.op_type = "concat"
-        x0 = np.random.random((2, 3, 2, 5)).astype('float32')
-        x1 = np.random.random((2, 3, 3, 5)).astype('float32')
+        x0 = np.random.random((2, 1, 4, 5)).astype('float32')
+        x1 = np.random.random((2, 2, 4, 5)).astype('float32')
         x2 = np.random.random((2, 3, 4, 5)).astype('float32')
-        axis = 2
+        axis = 1
         self.inputs = {'X': [('x0', x0), ('x1', x1), ('x2', x2)]}
         self.attrs = {'axis': axis}
         self.outputs = {'Out': np.concatenate((x0, x1, x2), axis=axis)}
@@ -17,6 +17,9 @@ class TestConcatOp(OpTest):
     def test_check_output(self):
         self.check_output()

+    def test_check_grad(self):
+        self.check_grad(['x0'], 'Out')
+
 if __name__ == '__main__':
     unittest.main()
@@ -7,11 +7,10 @@ class TestSplitOp(OpTest):
     def setUp(self):
         self.op_type = "split"
         axis = 0
-        num = 2
-        x = np.random.random((4, 2)).astype('float32')
-        out = np.split(x, num, axis)
+        x = np.random.random((4, 2, 5)).astype('float32')
+        out = np.split(x, [1, 3], axis)
         self.inputs = {'X': x}
-        self.attrs = {'axis': axis, 'num': num}
+        self.attrs = {'axis': axis, 'sections': [1, 2, 1]}
         self.outputs = {'Out': [('out%d' % i, out[i]) \
                                 for i in xrange(len(out))]}
@@ -19,7 +18,7 @@ class TestSplitOp(OpTest):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['X'], ['out0', 'out1'])
+        self.check_grad(['X'], ['out0', 'out1', 'out2'])

 if __name__ == '__main__':
...