From 95cceb2dd7b32a62b83d4264154f8a0290018f03 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 12 Mar 2021 10:14:02 +0800 Subject: [PATCH] [CustomOp] Support duplicable op input and output (#31535) * support duplicable op inout * add costom concat op test --- .../extension/include/ext_op_meta_info.h | 169 +++++++++++---- paddle/fluid/framework/custom_operator.cc | 201 ++++++++++++++---- .../fluid/tests/custom_op/CMakeLists.txt | 3 + .../fluid/tests/custom_op/concat_and_split.h | 84 ++++++++ .../fluid/tests/custom_op/custom_concat_op.cc | 145 +++++++++++++ .../tests/custom_op/test_custom_concat.py | 148 +++++++++++++ .../custom_op/test_custom_relu_op_jit.py | 1 - .../utils/cpp_extension/extension_utils.py | 13 +- 8 files changed, 670 insertions(+), 94 deletions(-) create mode 100644 python/paddle/fluid/tests/custom_op/concat_and_split.h create mode 100644 python/paddle/fluid/tests/custom_op/custom_concat_op.cc create mode 100644 python/paddle/fluid/tests/custom_op/test_custom_concat.py diff --git a/paddle/fluid/extension/include/ext_op_meta_info.h b/paddle/fluid/extension/include/ext_op_meta_info.h index a3b9a4c491..5b8d5a0bf5 100644 --- a/paddle/fluid/extension/include/ext_op_meta_info.h +++ b/paddle/fluid/extension/include/ext_op_meta_info.h @@ -56,32 +56,48 @@ using Tensor = paddle::Tensor; ///////////////// Util Define and Function //////////////// -inline std::string Grad(const std::string& var_name) { +constexpr char kGradTensorSuffix[] = "@GRAD"; +constexpr char kTensorVectorSuffix[] = "@VECTOR"; + +// Used for Construct Grad Tensor name +inline std::string Grad(const std::string& t_name) { + std::string result; + result.reserve(t_name.size() + 5U); + result += t_name; + result += kGradTensorSuffix; + return result; +} + +// Used for Construct std::vector name +inline std::string Vec(const std::string& t_name) { std::string result; - result.reserve(var_name.size() + 5U); - result += var_name; - result += "@GRAD"; + result.reserve(t_name.size() + 7U); + result += t_name; + result += kTensorVectorSuffix; return result; } ////////////////////// Kernel Function (PD_KERNEL) //////////////////////// // Record Op kernel core function -using KernelFunc = std::vector (*)(std::vector inputs, - std::vector attrs); +using KernelFunc = std::vector (*)( + std::vector inputs, std::vector> vec_inputs, + std::vector attrs); #define PD_SPECIALIZE_ComputeCallHelper(attr_type) \ template \ struct ComputeCallHelper { \ - template \ + template \ static Return Compute(std::vector inputs, \ + std::vector> vec_inputs, \ std::vector attrs, \ const PreviousArgs&... pargs) { \ try { \ attr_type arg = boost::any_cast(attrs[attr_idx]); \ - return ComputeCallHelper::template Compute( \ - inputs, attrs, pargs..., arg); \ + return ComputeCallHelper::template Compute< \ + in_idx, vec_in_idx, attr_idx + 1>(inputs, vec_inputs, attrs, \ + pargs..., arg); \ } catch (boost::bad_any_cast&) { \ PD_THROW( \ "Attribute cast error in custom operator. Expected " #attr_type \ @@ -99,9 +115,10 @@ struct KernelFuncImpl; template struct KernelFuncImpl { static Return Compute(std::vector inputs, + std::vector> vec_inputs, std::vector attrs) { - return ComputeCallHelper>::template Compute<0, 0>( - inputs, attrs); + return ComputeCallHelper>::template Compute<0, 0, 0>( + inputs, vec_inputs, attrs); } private: @@ -111,15 +128,32 @@ struct KernelFuncImpl { // for Tensor input template struct ComputeCallHelper { - template + template static Return Compute(std::vector inputs, + std::vector> vec_inputs, std::vector attrs, const PreviousArgs&... pargs) { - static_assert(attr_idx == 0, - "Input tensor should appear before attributes."); const Tensor& arg = inputs[in_idx]; - return ComputeCallHelper::template Compute( - inputs, attrs, pargs..., arg); + return ComputeCallHelper::template Compute( + inputs, vec_inputs, attrs, pargs..., arg); + } + }; + + // for std::vector input + template + struct ComputeCallHelper&, Tail...> { + template + static Return Compute(std::vector inputs, + std::vector> vec_inputs, + std::vector attrs, + const PreviousArgs&... pargs) { + const std::vector& arg = vec_inputs[vec_in_idx]; + return ComputeCallHelper::template Compute< + in_idx, vec_in_idx + 1, attr_idx>(inputs, vec_inputs, attrs, pargs..., + arg); } }; @@ -140,8 +174,9 @@ struct KernelFuncImpl { // end: base template template struct ComputeCallHelper> { - template + template static Return Compute(std::vector inputs, + std::vector> vec_inputs, std::vector attrs, const Args&... args) { return impl_fn(args...); } @@ -155,40 +190,62 @@ struct KernelFuncImpl { // Record Op infershape core function using InferShapeFunc = std::vector> (*)( - std::vector> input_shapes); + std::vector> input_shapes, + std::vector>> vec_input_shapes); template struct InferShapeFuncImpl; template struct InferShapeFuncImpl { - static Return InferShape(std::vector> input_shapes) { - return InferShapeCallHelper>::template InferShape<0>( - input_shapes); + static Return InferShape( + std::vector> input_shapes, + std::vector>> vec_input_shapes) { + return InferShapeCallHelper>::template InferShape<0, + 0>( + input_shapes, vec_input_shapes); } private: template struct InferShapeCallHelper; - // only one type input: std::vector template struct InferShapeCallHelper, Tail...> { - template - static Return InferShape(std::vector> input_shapes, - const PreviousArgs&... pargs) { + template + static Return InferShape( + std::vector> input_shapes, + std::vector>> vec_input_shapes, + const PreviousArgs&... pargs) { std::vector arg = input_shapes[in_idx]; - return InferShapeCallHelper::template InferShape( - input_shapes, pargs..., arg); + return InferShapeCallHelper::template InferShape( + input_shapes, vec_input_shapes, pargs..., arg); + } + }; + + template + struct InferShapeCallHelper>, Tail...> { + template + static Return InferShape( + std::vector> input_shapes, + std::vector>> vec_input_shapes, + const PreviousArgs&... pargs) { + std::vector> arg = vec_input_shapes[vec_in_idx]; + return InferShapeCallHelper::template InferShape( + input_shapes, vec_input_shapes, pargs..., arg); } }; // end: base template template struct InferShapeCallHelper> { - template - static Return InferShape(std::vector> input_shapes, - const Args&... args) { + template + static Return InferShape( + std::vector> input_shapes, + std::vector>> vec_input_shapes, + const Args&... args) { return impl_fn(args...); } }; @@ -200,41 +257,63 @@ struct InferShapeFuncImpl { /////////////// InferDataType Function (PD_INFER_DTYPE) /////////////// // Record Op Infer dtype core function -using InferDtypeFunc = - std::vector (*)(std::vector input_dtypes); +using InferDtypeFunc = std::vector (*)( + std::vector input_dtypes, + std::vector> vec_input_dtypes); template struct InferDtypeFuncImpl; template struct InferDtypeFuncImpl { - static Return InferDtype(std::vector input_dtypes) { - return InferDtypeCallHelper>::template InferDtype<0>( - input_dtypes); + static Return InferDtype( + std::vector input_dtypes, + std::vector> vec_input_dtypes) { + return InferDtypeCallHelper>::template InferDtype<0, + 0>( + input_dtypes, vec_input_dtypes); } private: template struct InferDtypeCallHelper; - // Only one type input now: DataType template struct InferDtypeCallHelper { - template - static Return InferDtype(std::vector input_dtypes, - const PreviousArgs&... pargs) { + template + static Return InferDtype( + std::vector input_dtypes, + std::vector> vec_input_dtypes, + const PreviousArgs&... pargs) { DataType arg = input_dtypes[in_idx]; - return InferDtypeCallHelper::template InferDtype( - input_dtypes, pargs..., arg); + return InferDtypeCallHelper::template InferDtype( + input_dtypes, vec_input_dtypes, pargs..., arg); + } + }; + + template + struct InferDtypeCallHelper, Tail...> { + template + static Return InferDtype( + std::vector input_dtypes, + std::vector> vec_input_dtypes, + const PreviousArgs&... pargs) { + std::vector arg = vec_input_dtypes[vec_in_idx]; + return InferDtypeCallHelper::template InferDtype( + input_dtypes, vec_input_dtypes, pargs..., arg); } }; // end: base template template struct InferDtypeCallHelper> { - template - static Return InferDtype(std::vector input_dtypes, - const Args&... args) { + template + static Return InferDtype( + std::vector input_dtypes, + std::vector> vec_input_dtypes, + const Args&... args) { return impl_fn(args...); } }; diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 66e28bb83c..0baacd4621 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -27,7 +27,6 @@ limitations under the License. */ #include "paddle/fluid/extension/include/ext_tensor.h" #include "paddle/fluid/framework/attribute.h" -#include "paddle/fluid/framework/c/c_api.h" #include "paddle/fluid/framework/custom_tensor_utils.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/op_meta_info_helper.h" @@ -63,6 +62,11 @@ inline bool IsGradVar(const std::string& var_name) { return var_name.rfind(suffix) != std::string::npos; } +inline bool IsDuplicableVar(const std::string& var_name) { + std::string suffix = kTensorVectorSuffix; + return var_name.rfind(suffix) != std::string::npos; +} + inline std::string NoGrad(const std::string& var_name) { std::string suffix = kGradVarSuffix; return var_name.substr(0, var_name.size() - kGradVarSuffixSize); @@ -103,19 +107,47 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx, const std::vector& attrs) { VLOG(1) << "Custom Operator: Start run KernelFunc."; std::vector custom_ins; + std::vector> custom_vec_ins; for (auto& in_name : inputs) { VLOG(1) << "Custom Operator: input name - " << in_name; - auto* x = ctx.Input(in_name); - PADDLE_ENFORCE_NOT_NULL(x, platform::errors::NotFound( - "Input tensor (%s) is nullptr.", in_name)); - PADDLE_ENFORCE_EQ(x->IsInitialized(), true, - platform::errors::InvalidArgument( - "Input tensor (%s) is not initialized.")); - auto custom_in = paddle::Tensor( - CustomTensorUtils::ConvertInnerPlaceToEnumPlace(x->place())); - CustomTensorUtils::ShareDataFrom(static_cast(x), custom_in); - CustomTensorUtils::SetTensorCurrentStream(&custom_in, ctx.GetPlace()); - custom_ins.emplace_back(custom_in); + if (detail::IsDuplicableVar(in_name)) { + // return const std::vector + auto vec_x = ctx.MultiInput(in_name); + PADDLE_ENFORCE_NE(vec_x.empty(), true, + platform::errors::NotFound( + "Input vector (%s) is empty.", in_name)); + std::vector custom_vec_in; + for (size_t i = 0; i < vec_x.size(); ++i) { + auto* x = vec_x[i]; + PADDLE_ENFORCE_NOT_NULL( + x, platform::errors::NotFound( + "The %d-th tensor in input vector (%s) is nullptr.", + i, in_name)); + PADDLE_ENFORCE_EQ(x->IsInitialized(), true, + platform::errors::InvalidArgument( + "The %d-th tensor in input vector (%s) " + "is not initialized.", + i, in_name)); + auto custom_t = paddle::Tensor( + CustomTensorUtils::ConvertInnerPlaceToEnumPlace(x->place())); + CustomTensorUtils::ShareDataFrom(static_cast(x), custom_t); + CustomTensorUtils::SetTensorCurrentStream(&custom_t, ctx.GetPlace()); + custom_vec_in.emplace_back(custom_t); + } + custom_vec_ins.emplace_back(custom_vec_in); + } else { + auto* x = ctx.Input(in_name); + PADDLE_ENFORCE_NOT_NULL(x, platform::errors::NotFound( + "Input tensor (%s) is nullptr.", in_name)); + PADDLE_ENFORCE_EQ(x->IsInitialized(), true, + platform::errors::InvalidArgument( + "Input tensor (%s) is not initialized.", in_name)); + auto custom_in = paddle::Tensor( + CustomTensorUtils::ConvertInnerPlaceToEnumPlace(x->place())); + CustomTensorUtils::ShareDataFrom(static_cast(x), custom_in); + CustomTensorUtils::SetTensorCurrentStream(&custom_in, ctx.GetPlace()); + custom_ins.emplace_back(custom_in); + } } std::vector custom_attrs; @@ -153,14 +185,34 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx, } } - VLOG(1) << "Run ComputeFunc."; + VLOG(1) << "Custom Operator: Run ComputeFunc."; try { - auto outs = func(custom_ins, custom_attrs); + auto outs = func(custom_ins, custom_vec_ins, custom_attrs); VLOG(1) << "Custom Operator: Share outputs into ExecutionContext."; for (size_t i = 0; i < outputs.size(); ++i) { - auto* true_out = ctx.Output(outputs[i]); - CustomTensorUtils::ShareDataTo(outs.at(i), true_out); + auto out_name = outputs[i]; + if (detail::IsDuplicableVar(out_name)) { + PADDLE_ENFORCE(i == 0UL && outputs.size() == 1UL, + platform::errors::PreconditionNotMet( + "If custom operator's outputs contains `paddle::Vec(" + ")` type, " + "it only can hold one output.")); + auto vec_true_outs = ctx.MultiOutput(out_name); + PADDLE_ENFORCE_EQ( + vec_true_outs.size(), outs.size(), + platform::errors::InvalidArgument( + "The number of element in custom operator outputs is wrong, " + "expected contains %d Tensors, but actually contains %d " + "Tensors.", + vec_true_outs.size(), outs.size())); + for (size_t j = 0; j < vec_true_outs.size(); ++j) { + CustomTensorUtils::ShareDataTo(outs.at(j), vec_true_outs.at(j)); + } + } else { + auto* true_out = ctx.Output(out_name); + CustomTensorUtils::ShareDataTo(outs.at(i), true_out); + } } } catch (platform::EnforceNotMet& exception) { throw std::move(exception); @@ -221,10 +273,20 @@ class CustomOpMaker : public OpProtoAndCheckerMaker { void Make() override { for (auto& in_name : inputs_) { - AddInput(in_name, "The input " + in_name + "of Custom operator."); + if (detail::IsDuplicableVar(in_name)) { + AddInput(in_name, "The input " + in_name + "of Custom operator.") + .AsDuplicable(); + } else { + AddInput(in_name, "The input " + in_name + "of Custom operator."); + } } for (auto& out_name : outputs_) { - AddOutput(out_name, "The output " + out_name + "of Custom Operator."); + if (detail::IsDuplicableVar(out_name)) { + AddOutput(out_name, "The output " + out_name + "of Custom Operator.") + .AsDuplicable(); + } else { + AddOutput(out_name, "The output " + out_name + "of Custom Operator."); + } } for (auto& attr : attrs_) { auto attr_name_and_type = detail::ParseAttrStr(attr); @@ -331,7 +393,13 @@ class CustomGradOpMaker : public SingleGradOpMaker { } for (auto& out_name : outputs_) { VLOG(1) << "Custom Operator: GradOpDescMaker - output: " << out_name; - grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name))); + if (detail::IsDuplicableVar(out_name)) { + grad_op->SetOutput(out_name, + this->InputGrad(detail::NoGrad(out_name), + /*drop_empty_grad=*/false)); + } else { + grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name))); + } } grad_op->SetAttrMap(this->Attrs()); } @@ -493,9 +561,9 @@ void RegisterOperatorWithMetaInfo( platform::errors::Unavailable( "Your custom operator contains multiple inputs. " "We only allow a custom operator that contains only one input " - "and " - "only one output without setting the InferShapeFn. At this time, " - "the input shape will be directly set to the output shape.\n" + "and only one output without setting the InferShapeFn. " + "At this time, the input shape will be directly set to " + "the output shape.\n" "Please set the InferShapeFn of custom " "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); PADDLE_ENFORCE_EQ( @@ -503,9 +571,9 @@ void RegisterOperatorWithMetaInfo( platform::errors::Unavailable( "Your custom operator contains multiple outputs. " "We only allow a custom operator that contains only one input " - "and " - "only one output without setting the InferShapeFn. At this time, " - "the input shape will be directly set to the output shape.\n" + "and only one output without setting the InferShapeFn. " + "At this time, the input shape will be directly set to " + "the output shape.\n" "Please set the InferShapeFn of custom " "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); @@ -516,21 +584,46 @@ void RegisterOperatorWithMetaInfo( info.infer_shape_ = [op_inputs, op_outputs, infer_shape_func](InferShapeContext* ctx) { std::vector> input_shapes; + std::vector>> vec_input_shapes; VLOG(1) << "Custom Operator: InferShape - get input ddim."; for (auto& in_name : op_inputs) { - OP_INOUT_CHECK(ctx->HasInput(in_name), "Input", in_name, "Custom"); - auto ddim = ctx->GetInputDim(in_name); - input_shapes.emplace_back(framework::vectorize(ddim)); + if (detail::IsDuplicableVar(in_name)) { + OP_INOUT_CHECK(ctx->HasInputs(in_name), "Input", in_name, "Custom"); + auto vec_ddim = ctx->GetInputsDim(in_name); + std::vector> vec_shape; + vec_shape.reserve(vec_ddim.size()); + std::transform(vec_ddim.begin(), vec_ddim.end(), + std::back_inserter(vec_shape), + [&](const DDim& ddim) -> std::vector { + return framework::vectorize(ddim); + }); + vec_input_shapes.emplace_back(vec_shape); + } else { + OP_INOUT_CHECK(ctx->HasInput(in_name), "Input", in_name, "Custom"); + auto ddim = ctx->GetInputDim(in_name); + input_shapes.emplace_back(framework::vectorize(ddim)); + } } VLOG(1) << "Custom Operator: InferShape - calc output ddim."; - auto output_shapes = infer_shape_func(input_shapes); + auto output_shapes = infer_shape_func(input_shapes, vec_input_shapes); VLOG(1) << "Custom Operator: InferShape - set output ddim."; for (size_t i = 0; i < op_outputs.size(); ++i) { - ctx->SetOutputDim(op_outputs[i], - framework::make_ddim(output_shapes[i])); + auto out_name = op_outputs[i]; + if (detail::IsDuplicableVar(out_name)) { + std::vector vec_ddim; + vec_ddim.reserve(output_shapes.size()); + std::transform(output_shapes.begin(), output_shapes.end(), + std::back_inserter(vec_ddim), + [&](const std::vector& shape) -> DDim { + return framework::make_ddim(shape); + }); + ctx->SetOutputsDim(out_name, vec_ddim); + } else { + ctx->SetOutputDim(out_name, framework::make_ddim(output_shapes[i])); + } } }; } @@ -544,9 +637,9 @@ void RegisterOperatorWithMetaInfo( platform::errors::Unavailable( "Your custom operator contains multiple inputs. " "We only allow a custom operator that contains only one input " - "and " - "only one output without setting the InferDtypeFn. At this time, " - "the input dtype will be directly set to the output dtype.\n" + "and only one output without setting the InferDtypeFn. " + "At this time, the input dtype will be directly set to " + "the output dtype.\n" "Please set the InferDtypeFn of custom " "operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))")); PADDLE_ENFORCE_EQ( @@ -554,9 +647,9 @@ void RegisterOperatorWithMetaInfo( platform::errors::Unavailable( "Your custom operator contains multiple outputs. " "We only allow a custom operator that contains only one input " - "and " - "only one output without setting the InferDtypeFn. At this time, " - "the input dtype will be directly set to the output dtype.\n" + "and only one output without setting the InferDtypeFn. " + "At this time, the input dtype will be directly set to " + "the output dtype.\n" "Please set the InferDtypeFn of custom " "operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))")); @@ -568,22 +661,42 @@ void RegisterOperatorWithMetaInfo( info.infer_var_type_ = [op_inputs, op_outputs, infer_dtype_func](InferVarTypeContext* ctx) { std::vector input_dtypes; + std::vector> vec_input_dtypes; VLOG(1) << "Custom Operator: InferDtype - get input dtype."; for (auto& in_name : op_inputs) { - auto dtype = ctx->GetInputDataType(in_name); - input_dtypes.emplace_back( - CustomTensorUtils::ConvertInnerDTypeToEnumDType(dtype)); + if (detail::IsDuplicableVar(in_name)) { + std::vector vec_custom_dtype; + for (size_t i = 0; i < ctx->InputSize(in_name); ++i) { + auto dtype = ctx->GetInputDataType(in_name, i); + vec_custom_dtype.emplace_back( + CustomTensorUtils::ConvertInnerDTypeToEnumDType(dtype)); + } + vec_input_dtypes.emplace_back(vec_custom_dtype); + } else { + auto dtype = ctx->GetInputDataType(in_name); + input_dtypes.emplace_back( + CustomTensorUtils::ConvertInnerDTypeToEnumDType(dtype)); + } } VLOG(1) << "Custom Operator: InferDtype - infer output dtype."; - auto output_dtypes = infer_dtype_func(input_dtypes); + auto output_dtypes = infer_dtype_func(input_dtypes, vec_input_dtypes); VLOG(1) << "Custom Operator: InferDtype - set output dtype."; for (size_t i = 0; i < op_outputs.size(); ++i) { - ctx->SetOutputDataType( - op_outputs[i], - CustomTensorUtils::ConvertEnumDTypeToInnerDType(output_dtypes[i])); + auto out_name = op_outputs[i]; + if (detail::IsDuplicableVar(out_name)) { + for (size_t j = 0; j < output_dtypes.size(); ++j) { + auto dtype = CustomTensorUtils::ConvertEnumDTypeToInnerDType( + output_dtypes[i]); + ctx->SetOutputDataType(out_name, dtype, j); + } + } else { + ctx->SetOutputDataType( + out_name, CustomTensorUtils::ConvertEnumDTypeToInnerDType( + output_dtypes[i])); + } } }; } diff --git a/python/paddle/fluid/tests/custom_op/CMakeLists.txt b/python/paddle/fluid/tests/custom_op/CMakeLists.txt index f57d22d871..620bff11a2 100644 --- a/python/paddle/fluid/tests/custom_op/CMakeLists.txt +++ b/python/paddle/fluid/tests/custom_op/CMakeLists.txt @@ -23,6 +23,9 @@ set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 120) py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py) set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120) +py_test(test_custom_concat SRCS test_custom_concat.py) +set_tests_properties(test_custom_concat PROPERTIES TIMEOUT 120) + py_test(test_check_abi SRCS test_check_abi.py) cc_test(test_check_error SRCS test_check_error.cc DEPS gtest) diff --git a/python/paddle/fluid/tests/custom_op/concat_and_split.h b/python/paddle/fluid/tests/custom_op/concat_and_split.h new file mode 100644 index 0000000000..9f24cc4369 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/concat_and_split.h @@ -0,0 +1,84 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/extension.h" + +int64_t GetRows(std::vector shape, int64_t axis) { + int64_t rows = 1; + for (int64_t i = 0; i < axis; ++i) { + rows *= shape[i]; + } + return rows; +} + +std::vector GetCols(const std::vector& ins, + int64_t rows, + int64_t* cols) { + std::vector cols_vec(ins.size()); + for (size_t i = 0; i < ins.size(); ++i) { + int64_t t_cols = ins[i].size() / rows; + *cols += t_cols; + cols_vec[i] = t_cols; + } + return cols_vec; +} + +template +void ConcatCpuKernel(const std::vector& ins, + paddle::Tensor* out, + int64_t axis) { + size_t num = ins.size(); + int64_t out_rows = GetRows(ins[0].shape(), axis); + int64_t out_cols = 0; + auto ins_cols = GetCols(ins, out_rows, &out_cols); + + auto* out_data = out->mutable_data(); + int64_t col_idx = 0; + for (size_t i = 0; i < num; ++i) { + int64_t col_len = ins_cols[i]; + auto* in_data = ins[i].data(); + for (int j = 0; j < out_rows; ++j) { + std::memcpy(out_data + j * out_cols + col_idx, + in_data + j * col_len, + sizeof(data_t) * col_len); + } + col_idx += col_len; + } +} + +template +void SplitCpuKernel(const paddle::Tensor& in, + const std::vector& ref_ins, + std::vector* outs, + int64_t axis) { + size_t num = outs->size(); + int64_t in_rows = GetRows(ref_ins[0].shape(), axis); + int64_t in_cols = 0; + auto out_cols = GetCols(ref_ins, in_rows, &in_cols); + + for (size_t i = 0; i < in_rows; ++i) { + auto* in_data = in.data() + i * in_cols; + int64_t col_idx = 0; + for (size_t j = 0; j < num; ++j) { + int64_t col_len = out_cols[j]; + auto* out_data = outs->at(j).mutable_data() + i * col_len; + std::memcpy(out_data, in_data + col_idx, sizeof(data_t) * col_len); + col_idx += col_len; + } + } +} diff --git a/python/paddle/fluid/tests/custom_op/custom_concat_op.cc b/python/paddle/fluid/tests/custom_op/custom_concat_op.cc new file mode 100644 index 0000000000..4ea3930399 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/custom_concat_op.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "concat_and_split.h" // NOLINT +#include "paddle/extension.h" + +#define CHECK_INPUT(x) \ + PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.") + +int64_t ComputeAxis(int64_t axis, int64_t rank) { + PD_CHECK(axis >= -rank && axis < rank, + "The axis is excepted to be in range of [", + -rank, + ", ", + rank, + "]."); + if (axis < 0) { + axis = axis + rank; + } + return axis > 0 ? axis : 0; +} + +std::vector ComputeOutShape( + std::vector> in_shapes, int64_t axis) { + size_t n = in_shapes.size(); + auto out_shape = in_shapes[0]; + size_t zero_dim_size = out_shape.size(); + for (size_t i = 1; i < n; ++i) { + PD_CHECK(in_shapes[i].size() == out_shape.size(), + "Input dimension must be same."); + for (size_t j = 0; j < zero_dim_size; ++j) { + if (j == axis) { + out_shape[axis] += in_shapes[i][j]; + } else { + PD_CHECK(in_shapes[0][j] == in_shapes[i][j], + "The ", + j, + "-th dimension of input must be same."); + } + } + } + return out_shape; +} + +std::vector ConcatForwardDynamicAxis( + const std::vector& inputs, const paddle::Tensor& axis_t) { + // check inputs + PD_CHECK(inputs.size() >= 1, "No Tensor need to be concat."); + for (auto& t : inputs) { + CHECK_INPUT(t); + } + CHECK_INPUT(axis_t); + + // compute output shape + int64_t rank = static_cast(inputs[0].shape().size()); + int64_t axis = axis_t.data()[0]; + axis = ComputeAxis(axis, rank); + std::vector> in_shapes; + for (auto& t : inputs) { + in_shapes.emplace_back(t.shape()); + } + auto out_shape = ComputeOutShape(in_shapes, axis); + + // create output + auto out = paddle::Tensor(paddle::PlaceType::kCPU); + out.reshape(out_shape); + + // calc + PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES( + inputs[0].type(), "ConcatCpuKernel", ([&] { + ConcatCpuKernel(inputs, &out, axis); + })); + + return {out}; +} + +std::vector ConcatBackwardDynamicAxis( + const std::vector& inputs, + const paddle::Tensor& grad_out, + const paddle::Tensor& axis_t) { + // check input + PD_CHECK(inputs.size() >= 1, "No Tensor need to be concat."); + for (auto& t : inputs) { + CHECK_INPUT(t); + } + CHECK_INPUT(axis_t); + CHECK_INPUT(grad_out); + + // compate axis + int64_t rank = static_cast(inputs[0].shape().size()); + int64_t axis = axis_t.data()[0]; + axis = ComputeAxis(axis, rank); + + // create outputs + std::vector grad_inputs; + for (auto& t : inputs) { + auto grad = paddle::Tensor(paddle::PlaceType::kCPU); + grad.reshape(t.shape()); + grad_inputs.emplace_back(grad); + } + + // calc + PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES( + grad_out.type(), "SplitCpuKernel", ([&] { + SplitCpuKernel(grad_out, inputs, &grad_inputs, axis); + })); + + return grad_inputs; +} + +std::vector> ConcatInferShapeDynamicAxis( + std::vector> input_shapes, + std::vector axis_shape) { + return {std::vector(input_shapes[0].size(), -1)}; +} + +std::vector ConcatInferDtypeDynamicAxis( + std::vector input_dtypes, paddle::DataType axis_dtype) { + return {input_dtypes[0]}; +} + +PD_BUILD_OP(custom_concat) + .Inputs({paddle::Vec("X"), "Axis"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(ConcatForwardDynamicAxis)) + .SetInferShapeFn(PD_INFER_SHAPE(ConcatInferShapeDynamicAxis)) + .SetInferDtypeFn(PD_INFER_DTYPE(ConcatInferDtypeDynamicAxis)); + +PD_BUILD_GRAD_OP(custom_concat) + .Inputs({paddle::Vec("X"), paddle::Grad("Out"), "Axis"}) + .Outputs({paddle::Grad(paddle::Vec("X"))}) + .SetKernelFn(PD_KERNEL(ConcatBackwardDynamicAxis)); diff --git a/python/paddle/fluid/tests/custom_op/test_custom_concat.py b/python/paddle/fluid/tests/custom_op/test_custom_concat.py new file mode 100644 index 0000000000..4086224cd7 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/test_custom_concat.py @@ -0,0 +1,148 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import numpy as np + +import paddle +import paddle.static as static +from paddle.utils.cpp_extension import load, get_build_directory +from paddle.utils.cpp_extension.extension_utils import run_cmd +from utils import paddle_includes, extra_cc_args, extra_nvcc_args + +# Because Windows don't use docker, the shared lib already exists in the +# cache dir, it will not be compiled again unless the shared lib is removed. +file = '{}\\custom_relu_module_jit\\custom_relu_module_jit.pyd'.format( + get_build_directory()) +if os.name == 'nt' and os.path.isfile(file): + cmd = 'del {}'.format(file) + run_cmd(cmd, True) + +if os.name == 'nt': + test_include = "..\\python\\paddle\\fluid\\tests\\custom_op" +else: + test_include = "../python/paddle/fluid/tests/custom_op" +paddle_includes.append(test_include) + +custom_ops = load( + name='custom_concat_jit', + sources=['custom_concat_op.cc'], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cxx_cflags=extra_cc_args, # test for cc flags + extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags + verbose=True) + + +def concat_dynamic(func, device, dtype, np_inputs, axis_v): + paddle.set_device(device) + inputs = [ + paddle.to_tensor( + x, dtype=dtype, place=device, stop_gradient=False) + for x in np_inputs + ] + axis = paddle.full(shape=[1], dtype='int64', fill_value=axis_v) + out = func(inputs, axis) + out.stop_gradient = False + out.backward() + grad_inputs = [x.grad for x in inputs] + return out.numpy(), grad_inputs + + +def concat_static(func, device, dtype, np_inputs, axis_v): + paddle.enable_static() + paddle.set_device(device) + with static.scope_guard(static.Scope()): + with static.program_guard(static.Program()): + x1 = static.data(name="x1", shape=[2, 3], dtype=dtype) + x2 = static.data(name="x2", shape=[2, 3], dtype=dtype) + axis = paddle.full(shape=[1], dtype='int64', fill_value=axis_v) + x1.stop_gradient = False + x2.stop_gradient = False + out = func([x1, x2], axis) + # mean only support float, so here use sum + sum_out = paddle.sum(out) + static.append_backward(sum_out) + + exe = static.Executor() + exe.run(static.default_startup_program()) + + out_v, x1_grad_v, x2_grad_v = exe.run( + static.default_main_program(), + feed={ + "x1": np_inputs[0].astype(dtype), + "x2": np_inputs[1].astype(dtype), + "axis": axis + }, + fetch_list=[out.name, x1.name + "@GRAD", x2.name + "@GRAD"]) + paddle.disable_static() + return out_v, x1_grad_v, x2_grad_v + + +class TestCustomConcatDynamicAxisJit(unittest.TestCase): + def setUp(self): + self.dtypes = ['float32', 'float64', 'int32', 'int64'] + self.devices = ['cpu'] + self.np_inputs = [ + np.array([[1, 2, 3], [4, 5, 6]]), + np.array([[11, 12, 13], [14, 15, 16]]) + ] + self.axises = [0, 1] + + def test_dynamic(self): + for device in self.devices: + for dtype in self.dtypes: + for axis in self.axises: + out, grad_inputs = concat_dynamic(custom_ops.custom_concat, + device, dtype, + self.np_inputs, axis) + pd_out, pd_grad_inputs = concat_dynamic( + paddle.concat, device, dtype, self.np_inputs, axis) + + self.assertTrue( + np.array_equal(out, pd_out), + "custom op out: {},\n paddle api out: {}".format( + out, pd_out)) + for x_grad, pd_x_grad in zip(grad_inputs, pd_grad_inputs): + self.assertTrue( + np.array_equal(x_grad, pd_x_grad), + "custom op x grad: {},\n paddle api x grad: {}". + format(x_grad, pd_x_grad)) + + def test_static(self): + for device in self.devices: + for dtype in self.dtypes: + for axis in self.axises: + out, x1_grad, x2_grad = concat_static( + custom_ops.custom_concat, device, dtype, self.np_inputs, + axis) + pd_out, pd_x1_grad, pd_x2_grad = concat_static( + paddle.concat, device, dtype, self.np_inputs, axis) + + self.assertTrue( + np.array_equal(out, pd_out), + "custom op out: {},\n paddle api out: {}".format( + out, pd_out)) + self.assertTrue( + np.array_equal(x1_grad, pd_x1_grad), + "custom op x1_grad: {},\n paddle api x1_grad: {}". + format(x1_grad, pd_x1_grad)) + self.assertTrue( + np.array_equal(x2_grad, pd_x2_grad), + "custom op x2_grad: {},\n paddle api x2_grad: {}". + format(x2_grad, pd_x2_grad)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py index 34cf38aacf..1a96fc5f0a 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py @@ -13,7 +13,6 @@ # limitations under the License. import os -import subprocess import unittest import paddle import numpy as np diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index fff92d85c8..b68100fe52 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -781,13 +781,18 @@ def _get_api_inputs_str(op_name): in_names, out_names, attr_names = parse_op_info(op_name) # e.g: x, y, z param_names = in_names + attr_names - params_str = ','.join([p.lower() for p in param_names]) + # NOTE(chenweihang): we add suffix `@VECTOR` for std::vector input, + # but the string contains `@` cannot used as argument name, so we split + # input name by `@`, and only use first substr as argument + params_str = ','.join([p.split("@")[0].lower() for p in param_names]) # e.g: {'X': x, 'Y': y, 'Z': z} - ins_str = "{%s}" % ','.join( - ["'{}' : {}".format(in_name, in_name.lower()) for in_name in in_names]) + ins_str = "{%s}" % ','.join([ + "'{}' : {}".format(in_name, in_name.split("@")[0].lower()) + for in_name in in_names + ]) # e.g: {'num': n} attrs_str = "{%s}" % ",".join([ - "'{}' : {}".format(attr_name, attr_name.lower()) + "'{}' : {}".format(attr_name, attr_name.split("@")[0].lower()) for attr_name in attr_names ]) # e.g: ['Out', 'Index'] -- GitLab