diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 77c7c855c0ffed5032e639237b01037a990652c4..cb401402f9d972bf72cfc286119675ce413c64cd 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include #include diff --git a/paddle/memory/memcpy.cc b/paddle/memory/memcpy.cc index 19ec9ba9b26f5919796181a19a048b7edb508bdd..c96a697a7e022684688b31c05da43e52812100d8 100644 --- a/paddle/memory/memcpy.cc +++ b/paddle/memory/memcpy.cc @@ -80,6 +80,15 @@ void Copy(platform::GPUPlace dst_place, platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice); } +template <> +void Copy(platform::GPUPlace dst_place, + void* dst, + platform::GPUPlace src_place, + const void* src, size_t num) { + platform::SetDeviceId(dst_place.device); + platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice); +} + #endif // PADDLE_ONLY_CPU } // namespace memory diff --git a/paddle/operators/concat_op.cc b/paddle/operators/concat_op.cc index 01cbfc33efcb4042438fbb398fbcca9457f1334f..1ffa02c8f94c01a385d3ba376c1fd0dc3c1bd372 100644 --- a/paddle/operators/concat_op.cc +++ b/paddle/operators/concat_op.cc @@ -25,12 +25,14 @@ class ConcatOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL, + "Inputs(X) of ConcatOp should be empty.") PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of ConcatOp should not be null."); auto ins = ctx->GetInputsDim("X"); size_t axis = static_cast(ctx->Attrs().Get("axis")); - size_t n = ins.size(); + const size_t n = ins.size(); PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1."); @@ -72,10 +74,27 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { } }; +class ConcatOpGrad : public framework::OperatorWithKernel { + public: + ConcatOpGrad(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + protected: + void InferShape(framework::InferShapeContextBase *ctx) const override { + ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X")); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(concat, ops::ConcatOp, ops::ConcatOpMaker) +REGISTER_OP(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad, + ops::ConcatOpGrad) REGISTER_OP_CPU_KERNEL(concat, ops::ConcatKernel) +REGISTER_OP_CPU_KERNEL(concat_grad, + ops::ConcatGradKernel) diff --git a/paddle/operators/concat_op.cu b/paddle/operators/concat_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..ede832ddcd486729db56bba016683b33875f8837 --- /dev/null +++ b/paddle/operators/concat_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/concat_op.h" +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(concat, + ops::ConcatKernel); +REGISTER_OP_GPU_KERNEL( + concat_grad, ops::ConcatGradKernel); diff --git a/paddle/operators/concat_op.h b/paddle/operators/concat_op.h index f977054fdf8aa0164db726b94a21c57f770dd674..b37063261123bce1f22c39ab021e88f2faf58e9f 100644 --- a/paddle/operators/concat_op.h +++ b/paddle/operators/concat_op.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/framework/op_registry.h" +#include "paddle/operators/strided_memcpy.h" namespace paddle { namespace operators { @@ -27,35 +28,39 @@ class ConcatKernel : public framework::OpKernel { auto ins = ctx.MultiInput("X"); auto* out = ctx.Output("Out"); int64_t axis = static_cast(ctx.Attr("axis")); - size_t n = ins.size(); - size_t output_axis_dim = 0; - size_t before = 1, after = 1; - for (size_t i = 0; i < n; i++) { - output_axis_dim += ins[i]->dims()[axis]; - } - auto& input_zero = ins[0]; - for (int64_t i = 0; i < input_zero->dims().size(); i++) { - if (i == axis) { - continue; - } - if (i < axis) { - before *= input_zero->dims()[i]; - } else { - after *= input_zero->dims()[i]; - } - } + const size_t n = ins.size(); size_t output_offset = 0; + out->mutable_data(ctx.GetPlace()); + auto out_stride = framework::stride(out->dims()); for (size_t i = 0; i < n; i++) { auto& in = ins[i]; auto axis_dim = in->dims()[axis]; - for (size_t j = 0; j < before; j++) { - size_t len = axis_dim * after * sizeof(T); - const T* src = in->data() + axis_dim * after * j; - T* out_data = out->mutable_data(platform::CPUPlace()); - T* dest = out_data + output_offset + output_axis_dim * after * j; - memcpy(dest, src, len); - } - output_offset += axis_dim * after; + auto in_stride = framework::stride(in->dims()); + StridedMemcpy(ctx.device_context(), in->data(), in_stride, + in->dims(), out_stride, out->data() + output_offset); + output_offset += axis_dim * in_stride[axis]; + } + } +}; + +template +class ConcatGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* in = ctx.Input(framework::GradVarName("Out")); + auto outs = ctx.MultiOutput(framework::GradVarName("X")); + int64_t axis = static_cast(ctx.Attr("axis")); + const size_t n = outs.size(); + size_t input_offset = 0; + auto in_stride = framework::stride(in->dims()); + for (size_t i = 0; i < n; i++) { + auto& out = outs[i]; + out->mutable_data(ctx.GetPlace()); + size_t axis_dim = out->dims()[axis]; + auto out_stride = framework::stride(out->dims()); + StridedMemcpy(ctx.device_context(), in->data() + input_offset, + in_stride, out->dims(), out_stride, out->data()); + input_offset += axis_dim * in_stride[axis]; } } }; diff --git a/paddle/operators/split_op.cc b/paddle/operators/split_op.cc index 8640d1010ef6ae352a93ee2fd7b771a90c6efa5c..5f4b5539affef6fe1d3c4d15fff77d983b5e107f 100644 --- a/paddle/operators/split_op.cc +++ b/paddle/operators/split_op.cc @@ -25,6 +25,10 @@ class SplitOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SplitOp should not be null."); + PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL, + "Outputs(Out) of SplitOp should not be empty."); auto in_dims = ctx->GetInputDim("X"); auto outs_names = ctx->Outputs("Out"); size_t axis = static_cast(ctx->Attrs().Get("axis")); @@ -55,9 +59,6 @@ class SplitOp : public framework::OperatorWithKernel { dim[axis] = sections[i]; outs_dims.push_back(dim); } - } else { - PADDLE_ENFORCE_NOT_NULL(nullptr, "split operator should", - " specify indices or sections."); } ctx->SetOutputsDim("Out", outs_dims); } @@ -117,4 +118,4 @@ USE_CPU_ONLY_OP(concat); REGISTER_OP(split, ops::SplitOp, ops::SplitOpMaker, split_grad, ops::SplitOpGrad); REGISTER_OP_CPU_KERNEL(split, - ops::SplitKernel); + ops::SplitOpKernel); diff --git a/paddle/operators/split_op.cu b/paddle/operators/split_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..93d1fc3c44cbc146c945c51af1abe6494572d1ae --- /dev/null +++ b/paddle/operators/split_op.cu @@ -0,0 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/split_op.h" +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(split, + ops::SplitOpKernel); diff --git a/paddle/operators/split_op.h b/paddle/operators/split_op.h index 860690ee895075fda9ddef08776a2102642efff9..8ab8e0ee4fea621b34da73507c53846100d61a17 100644 --- a/paddle/operators/split_op.h +++ b/paddle/operators/split_op.h @@ -16,44 +16,29 @@ limitations under the License. */ #include #include "paddle/framework/op_registry.h" +#include "paddle/operators/strided_memcpy.h" namespace paddle { namespace operators { template -class SplitKernel : public framework::OpKernel { +class SplitOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* in = ctx.Input("X"); auto outs = ctx.MultiOutput("Out"); + auto in_stride = framework::stride(in->dims()); int64_t axis = static_cast(ctx.Attr("axis")); - size_t before = 1, after = 1; const size_t n = outs.size(); - size_t input_axis_dim = in->dims()[axis]; - - for (int64_t i = 0; i < in->dims().size(); ++i) { - if (i == axis) { - continue; - } - if (i < axis) { - before *= in->dims()[i]; - } else { - after *= in->dims()[i]; - } - } size_t input_offset = 0; for (size_t i = 0; i < n; i++) { auto& out = outs[i]; + out->mutable_data(ctx.GetPlace()); size_t axis_dim = out->dims()[axis]; - for (size_t j = 0; j < before; j++) { - size_t len = axis_dim * after * sizeof(T); - T* dest = - out->mutable_data(platform::CPUPlace()) + axis_dim * after * j; - const T* src = - in->data() + input_offset + input_axis_dim * after * j; - memcpy(dest, src, len); - } - input_offset += axis_dim * after; + auto out_stride = framework::stride(out->dims()); + StridedMemcpy(ctx.device_context(), in->data() + input_offset, + in_stride, out->dims(), out_stride, out->data()); + input_offset += axis_dim * in_stride[axis]; } } }; diff --git a/python/paddle/v2/framework/tests/test_concat_op.py b/python/paddle/v2/framework/tests/test_concat_op.py index 656563f96e52df30951ec0ec7042ad9c530e90b2..a792d1c106ac00efd92e680cfad67f41a7520e26 100644 --- a/python/paddle/v2/framework/tests/test_concat_op.py +++ b/python/paddle/v2/framework/tests/test_concat_op.py @@ -6,10 +6,10 @@ from op_test import OpTest class TestConcatOp(OpTest): def setUp(self): self.op_type = "concat" - x0 = np.random.random((2, 3, 2, 5)).astype('float32') - x1 = np.random.random((2, 3, 3, 5)).astype('float32') + x0 = np.random.random((2, 1, 4, 5)).astype('float32') + x1 = np.random.random((2, 2, 4, 5)).astype('float32') x2 = np.random.random((2, 3, 4, 5)).astype('float32') - axis = 2 + axis = 1 self.inputs = {'X': [('x0', x0), ('x1', x1), ('x2', x2)]} self.attrs = {'axis': axis} self.outputs = {'Out': np.concatenate((x0, x1, x2), axis=axis)} @@ -17,6 +17,9 @@ class TestConcatOp(OpTest): def test_check_output(self): self.check_output() + def test_check_grad(self): + self.check_grad(['x0'], 'Out') + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_split_op.py b/python/paddle/v2/framework/tests/test_split_op.py index b4420db9d71b99556e305104ac17ef5e4b4bd0f2..37c6ebb89d1c3bcfc3c80a54a1e92c0326e046e3 100644 --- a/python/paddle/v2/framework/tests/test_split_op.py +++ b/python/paddle/v2/framework/tests/test_split_op.py @@ -7,11 +7,10 @@ class TestSplitOp(OpTest): def setUp(self): self.op_type = "split" axis = 0 - num = 2 - x = np.random.random((4, 2)).astype('float32') - out = np.split(x, num, axis) + x = np.random.random((4, 2, 5)).astype('float32') + out = np.split(x, [1, 3], axis) self.inputs = {'X': x} - self.attrs = {'axis': axis, 'num': num} + self.attrs = {'axis': axis, 'sections': [1, 2, 1]} self.outputs = {'Out': [('out%d' % i, out[i]) \ for i in xrange(len(out))]} @@ -19,7 +18,7 @@ class TestSplitOp(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(['X'], ['out0', 'out1']) + self.check_grad(['X'], ['out0', 'out1', 'out2']) if __name__ == '__main__':