diff --git a/cmake/operators.cmake b/cmake/operators.cmake index ecf2dbc81762a59d4d826ae8f5dfc0ab48a28910..e927fae63f0fc28902431c7b09350c7f7d10c52a 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -13,7 +13,7 @@ function(op_library TARGET) set(CUDNN_FILE) set(mkldnn_cc_srcs) set(MKLDNN_FILE) - set(op_common_deps operator op_registry math_function layer) + set(op_common_deps operator op_registry math_function layer common_infer_shape_functions) set(options "") set(oneValueArgs "") set(multiValueArgs SRCS DEPS) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 66fe71a80a7b0165a0d4afb38c89fc1fdb339190..78595e50b2da627065309041079839faa197cc8f 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -13,12 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/op_desc.h" + #include #include #include // NOLINT #include #include #include + #include "glog/logging.h" #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/op_call_stack.h" @@ -51,6 +53,29 @@ class CompileTimeInferShapeContext : public InferShapeContext { std::vector Outputs(const std::string &name) const override; + std::string GetInputNameByIdx(size_t idx) const override { + auto &op_proto = + paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_; + PADDLE_ENFORCE_LT(idx, op_proto->inputs().size(), + platform::errors::OutOfRange( + "The index should be less than the size of inputs of " + "operator %s, but got index is %d and size is %d", + op_.Type(), idx, op_proto->inputs().size())); + return op_proto->inputs()[idx].name(); + } + + std::string GetOutputNameByIdx(size_t idx) const override { + auto &op_proto = + paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_; + PADDLE_ENFORCE_LT( + idx, op_proto->outputs().size(), + platform::errors::OutOfRange( + "The index should be less than the size of outputs of " + "operator %s, but got index is %d and size is %d", + op_.Type(), idx, op_proto->outputs().size())); + return op_proto->outputs()[idx].name(); + } + void ShareDim(const std::string &in, const std::string &out, size_t i = 0, size_t j = 0) override { PADDLE_ENFORCE_LT(i, Inputs(in).size()); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 9c293bcdb852ff1ab5b1494838ee2c947cd372cc..c8c18bcee6a8868919c584527c088725c1c9d58d 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/operator.h" + #include #include @@ -20,13 +22,13 @@ limitations under the License. */ #include #include #include + #include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_call_stack.h" #include "paddle/fluid/framework/op_proto_maker.h" -#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/unused_var_check.h" @@ -604,6 +606,29 @@ class RuntimeInferShapeContext : public InferShapeContext { return op_.Outputs(name); } + std::string GetInputNameByIdx(size_t idx) const override { + auto& op_proto = + paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_; + PADDLE_ENFORCE_LT(idx, op_proto->inputs().size(), + platform::errors::OutOfRange( + "The index should be less than the size of inputs of " + "operator %s, but got index is %d and size is %d", + op_.Type(), idx, op_proto->inputs().size())); + return op_proto->inputs()[idx].name(); + } + + std::string GetOutputNameByIdx(size_t idx) const override { + auto& op_proto = + paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_; + PADDLE_ENFORCE_LT( + idx, op_proto->outputs().size(), + platform::errors::OutOfRange( + "The index should be less than the size of outputs of " + "operator %s, but got index is %d and size is %d", + op_.Type(), idx, op_proto->outputs().size())); + return op_proto->outputs()[idx].name(); + } + void ShareDim(const std::string& in, const std::string& out, size_t i = 0, size_t j = 0) override { auto in_it = ctx_.inputs.find(in); diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index 7ce8deb7cfc70d39de52e1fd9e5bace969f854e7..8d8a8f01b3f38c82a480bf7204721481586cc860 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include + #include "paddle/fluid/framework/attribute.h" #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/framework.pb.h" @@ -52,7 +53,8 @@ class InferShapeContext { const std::vector &dims) = 0; virtual void SetReaderDims(const std::string &name, const std::vector &dims); - + virtual std::string GetInputNameByIdx(size_t idx) const = 0; + virtual std::string GetOutputNameByIdx(size_t idx) const = 0; virtual AttrReader Attrs() const = 0; virtual std::vector Inputs(const std::string &name) const = 0; virtual std::vector Outputs(const std::string &name) const = 0; diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h index 65ac570bc7aa07a1a06e9deffcf797d6ef5d2519..fcd4545a2c82d3c64f8d8d8683438aaf0e6a2719 100644 --- a/paddle/fluid/imperative/infer_shape_context.h +++ b/paddle/fluid/imperative/infer_shape_context.h @@ -16,7 +16,9 @@ #include #include + #include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/imperative/type_defs.h" @@ -32,8 +34,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext { public: DygraphInferShapeContext(const NameVarMap* in, const NameVarMap* out, - const framework::AttributeMap* attr) - : var_base_map_in_(in), var_base_map_out_(out), attrs_(attr) {} + const framework::AttributeMap* attr, + const std::string op_type) + : var_base_map_in_(in), + var_base_map_out_(out), + attrs_(attr), + op_type_(op_type) {} bool HasInput(const std::string& name) const override { // has only one input @@ -135,6 +141,28 @@ class DygraphInferShapeContext : public framework::InferShapeContext { return vec_res; } + std::string GetInputNameByIdx(size_t idx) const override { + auto& op_proto = + paddle::framework::OpInfoMap::Instance().Get(op_type_).proto_; + PADDLE_ENFORCE_LT(idx, op_proto->inputs().size(), + platform::errors::OutOfRange( + "The index should be less than the size of inputs of " + "operator %s, but got index is %d and size is %d", + op_type_, idx, op_proto->inputs().size())); + return op_proto->inputs()[idx].name(); + } + + std::string GetOutputNameByIdx(size_t idx) const override { + auto& op_proto = + paddle::framework::OpInfoMap::Instance().Get(op_type_).proto_; + PADDLE_ENFORCE_LT( + idx, op_proto->outputs().size(), + platform::errors::OutOfRange( + "The index should be less than the size of outputs of " + "operator %s, but got index is %d and size is %d", + op_type_, idx, op_proto->outputs().size())); + return op_proto->outputs()[idx].name(); + } void ShareDim(const std::string& in, const std::string& out, size_t i = 0, size_t j = 0) override { @@ -367,6 +395,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { const NameVarMap* var_base_map_in_; const NameVarMap* var_base_map_out_; const framework::AttributeMap* attrs_; + const std::string op_type_; }; } // namespace imperative diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index def5c860449214ad4a08fd69ff575b91d6f162a0..82b91d2e77292dbefae54d0f7ecb7a2aff00f979 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -13,7 +13,9 @@ // limitations under the License. #include "paddle/fluid/imperative/prepared_operator.h" + #include + #include "paddle/fluid/imperative/execution_context.h" #include "paddle/fluid/imperative/infer_shape_context.h" #include "paddle/fluid/imperative/infer_var_type_context.h" @@ -137,7 +139,8 @@ static void PreparedOpRunImpl( // TODO(zjl): remove scope in dygraph framework::Scope scope; - DygraphInferShapeContext infer_shape_ctx(&ins, &outs, &attrs); + DygraphInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, + op.Type()); static_cast(op).InferShape( &infer_shape_ctx); diff --git a/paddle/fluid/imperative/tests/test_layer.cc b/paddle/fluid/imperative/tests/test_layer.cc index a231e16100b9f6b153beffe7c66de6fc6813414e..4a30ffb7e3d01ffa90a42278e2e5ef5271045d8a 100644 --- a/paddle/fluid/imperative/tests/test_layer.cc +++ b/paddle/fluid/imperative/tests/test_layer.cc @@ -17,9 +17,11 @@ // #include + #include #include #include + #include "gtest/gtest.h" #include "paddle/fluid/imperative/execution_context.h" #include "paddle/fluid/imperative/infer_shape_context.h" @@ -384,7 +386,7 @@ TEST(test_layer, test_dygraph_infershape_context) { concat_att_map["axis"] = 1; DygraphInferShapeContext infer_shape_ctx( - &ins, &outs, &concat_att_map); + &ins, &outs, &concat_att_map, "dummy"); bool have_x = infer_shape_ctx.HasOutputs("Out"); ASSERT_EQ(have_x, true); diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 012b16a6a05f3d5fec3636b0a790d4d67334295f..e74f363d886e4601d07c1b2a7d79d8c915b59e93 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -86,12 +86,14 @@ if (WITH_DGC) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dgc) endif() +cc_library(common_infer_shape_functions SRCS common_infer_shape_functions.cc DEPS operator) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor device_memory_aligment) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions) if (WITH_GPU) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor) endif() @@ -111,6 +113,7 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} tensor_formatter) set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS}) set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies") +cc_test(test_common_infer_shape_functions SRCS test_common_infer_shape_functions.cc DEPS common_infer_shape_functions ${COMMON_OP_DEPS} activation_op elementwise_add_op softmax_op softmax) cc_test(assign_op_test SRCS assign_op_test.cc DEPS assign_op) cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor math_function) diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 7ea78879e1e08a9690f4a7a966d3e7c3decd293b..107d333d3a8593e6e3d7afb38c7688d80f2441f8 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -13,11 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/activation_op.h" + #include #include #include #include #include + +#include "paddle/fluid/operators/common_infer_shape_functions.h" #include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h" #include "paddle/fluid/platform/port.h" #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/operators/common_infer_shape_functions.cc b/paddle/fluid/operators/common_infer_shape_functions.cc new file mode 100644 index 0000000000000000000000000000000000000000..22b212fc1b9f8844f0ae3555ac6d63af1f48d1cd --- /dev/null +++ b/paddle/fluid/operators/common_infer_shape_functions.cc @@ -0,0 +1,166 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/common_infer_shape_functions.h" + +#include +#include + +// This file almostly contains all the infershape functions that are used in +// operators. + +namespace paddle { +namespace operators { +namespace details { +inline void GetBroadcastDimsArrays(const framework::DDim &x_dims, + const framework::DDim &y_dims, + int *x_dims_array, int *y_dims_array, + int *out_dims_array, const int max_dim, + const int axis) { + PADDLE_ENFORCE_GE( + axis, 0, + platform::errors::InvalidArgument( + "Axis should be great than or equal to 0, but received axis is %d.", + axis)); + PADDLE_ENFORCE_LT(axis, max_dim, + platform::errors::InvalidArgument( + "Axis should be less than %d, but received axis is %d.", + max_dim, axis)); + if (x_dims.size() > y_dims.size()) { + std::fill(y_dims_array, y_dims_array + axis, 1); + if (axis + y_dims.size() < max_dim) { + std::fill(y_dims_array + axis + y_dims.size(), y_dims_array + max_dim, 1); + } + std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array); + std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array + axis); + } else { + std::fill(x_dims_array, x_dims_array + axis, 1); + if (axis + x_dims.size() < max_dim) { + std::fill(x_dims_array + axis + x_dims.size(), x_dims_array + max_dim, 1); + } + std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array + axis); + std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array); + } + + for (int i = 0; i < max_dim; i++) { + PADDLE_ENFORCE_EQ( + x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 || + y_dims_array[i] <= 1, + true, platform::errors::InvalidArgument( + "Broadcast dimension mismatch. Operands could " + "not be broadcast together with the shape of X = [%s] and " + "the shape of Y = [%s]. Received [%d] in X is not equal to " + "[%d] in Y at i:%d.", + x_dims, y_dims, x_dims_array[i], y_dims_array[i], i)); + if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) || + (x_dims_array[i] == 1 && y_dims_array[i] == 1)) { + out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]); + } else { + out_dims_array[i] = -1; + } + } +} +} // namespace details + +// shape input(0) -> output(0) without change. +void UnaryOpUnchangedInferShape(framework::InferShapeContext *ctx) { + auto x_name = ctx->GetInputNameByIdx(0); + auto out_name = ctx->GetOutputNameByIdx(0); + ctx->ShareDim(x_name, /*->*/ out_name); + ctx->ShareLoD(x_name, /*->*/ out_name); +} + +// shape input(0) -> output(0) without change, check if axis in range [-Rank(x), +// Rank(x)-1] +void UnaryOpUnchangedInferShapeCheckAxis(framework::InferShapeContext *ctx) { + auto x_name = ctx->GetInputNameByIdx(0); + auto out_name = ctx->GetOutputNameByIdx(0); + auto x_dim = ctx->GetInputDim(x_name); + auto x_rank = x_dim.size(); + auto axis = ctx->Attrs().Get("axis"); + PADDLE_ENFORCE_GE( + axis, -x_rank, + platform::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], " + "R is the rank of Input(X). But received axis: %d, R: %d.", + axis, x_rank)); + PADDLE_ENFORCE_LT( + axis, x_rank, + platform::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], " + "R is the rank of Input(X). But received axis: %d, R: %d.", + axis, x_rank)); + ctx->ShareDim(x_name, /*->*/ out_name); + ctx->ShareLoD(x_name, /*->*/ out_name); +} + +// broadcast input(0) and input(1) -> output(0) +void BinaryOpBroadcastInferShape(framework::InferShapeContext *ctx) { + auto x_name = ctx->GetInputNameByIdx(0); + auto y_name = ctx->GetInputNameByIdx(1); + auto out_name = ctx->GetOutputNameByIdx(0); + auto x_dims = ctx->GetInputDim(x_name); + auto y_dims = ctx->GetInputDim(y_name); + PADDLE_ENFORCE_EQ( + ctx->GetInputsVarType(y_name).front(), + framework::proto::VarType::LOD_TENSOR, + platform::errors::InvalidArgument( + "The var type of input %s should be LoDTensor, but got %s.", + ctx->Inputs(y_name).front(), ctx->GetInputsVarType(y_name).front())); + + if (ctx->GetInputsVarType(x_name).front() == + framework::proto::VarType::SELECTED_ROWS) { + PADDLE_ENFORCE_EQ(y_dims.size(), 1u, + platform::errors::InvalidArgument( + "For binary broadcastable operator, if X is " + "Sparse(VarType.SELECTED_ROWS" + "), Y must be scalar, and the size of Y should be 1. " + "But reveived the size of Y = %s.", + y_dims.size())); + PADDLE_ENFORCE_EQ( + y_dims[0], 1, + platform::errors::InvalidArgument( + "For binary broadcastable operator, if X is " + "Sparse(VarType.SELECTED_ROWS" + "), Y must be scalar, the first dimension of Y should be 1. " + "But reveived the first dimension of Y = %s.", + y_dims[0])); + } else if (ctx->GetInputsVarType(x_name).front() != + framework::proto::VarType::LOD_TENSOR) { + PADDLE_THROW(platform::errors::InvalidArgument( + "For binary broadcastable operator, the var type of input X should " + "be LOD_TENSOR, but got %s", + ctx->GetInputsVarType(x_name).front())); + } + + if (x_dims == y_dims) { + ctx->ShareDim(x_name, /*->*/ out_name); + ctx->ShareLoD(x_name, /*->*/ out_name); + } else { + int max_dim = std::max(x_dims.size(), y_dims.size()); + int axis = ctx->Attrs().Get("axis"); + axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + std::vector out_dims_array(max_dim); + details::GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), + y_dims_array.data(), out_dims_array.data(), + max_dim, axis); + ctx->SetOutputDim(out_name, framework::make_ddim(out_dims_array)); + ctx->ShareLoD(x_name, /*->*/ out_name); + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/common_infer_shape_functions.h b/paddle/fluid/operators/common_infer_shape_functions.h new file mode 100644 index 0000000000000000000000000000000000000000..2cb9eab2865ce068a4f776bc63070c59bf029481 --- /dev/null +++ b/paddle/fluid/operators/common_infer_shape_functions.h @@ -0,0 +1,35 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +// This file almostly contains all the infershape functions that are used in +// operators. + +namespace paddle { +namespace operators { + +// shape input(0) -> output(0) without change. +void UnaryOpUnchangedInferShape(framework::InferShapeContext* ctx); +// shape input(0) -> output(0) without change, check if axis in range [-Rank(x), +// Rank(x)-1] +void UnaryOpUnchangedInferShapeCheckAxis(framework::InferShapeContext* ctx); +// broadcast input(0) and input(1) -> output(0) +void BinaryOpBroadcastInferShape(framework::InferShapeContext* ctx); + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index d14f8ae65feaec1c6d536cb18f366d96e137108b..ece6af1b5a6f562bd7ff81290f98e8636feabb0c 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -19,9 +19,11 @@ limitations under the License. */ #include #include #include + #include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/common_infer_shape_functions.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/fluid/operators/selu_op.cc b/paddle/fluid/operators/selu_op.cc index 7c77b2688e7b528f678418c67e77fa4abff04248..0adf61d7ce3e5b5792b9dc65d5ac8f884dc81ea5 100644 --- a/paddle/fluid/operators/selu_op.cc +++ b/paddle/fluid/operators/selu_op.cc @@ -13,10 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/selu_op.h" + #include #include #include +#include "paddle/fluid/operators/common_infer_shape_functions.h" + namespace paddle { namespace operators { @@ -28,11 +31,7 @@ class SeluOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "selu"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "selu"); - - ctx->ShareDim("X", /*->*/ "Out"); - ctx->ShareLoD("X", /*->*/ "Out"); + return UnaryOpUnchangedInferShape(ctx); } protected: diff --git a/paddle/fluid/operators/test_common_infer_shape_functions.cc b/paddle/fluid/operators/test_common_infer_shape_functions.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca8f6ce84fc571674fdfe6f29cbcd82a98fd8fcf --- /dev/null +++ b/paddle/fluid/operators/test_common_infer_shape_functions.cc @@ -0,0 +1,145 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/imperative/infer_shape_context.h" +#include "paddle/fluid/imperative/layer.h" +#include "paddle/fluid/operators/common_infer_shape_functions.h" + +USE_OP(relu); +USE_OP(elementwise_add); +USE_OP(softmax); + +namespace paddle { +namespace operators { +namespace details { + +class DygraphInferShapeTest { + public: + void AddInput(const std::string& name, const framework::DDim& dim) { + std::shared_ptr vin( + new imperative::VarBase(false, name)); + vin->MutableVar()->GetMutable()->Resize(dim); + ins_[name] = {vin}; + } + void AddOutput(const std::string& name, const framework::DDim& expected_dim) { + std::shared_ptr vout( + new imperative::VarBase(false, name)); + vout->MutableVar() + ->GetMutable(); // InitializeVariable + outs_[name] = {vout}; + expected_dims_[name] = expected_dim; + } + void AddAttrs(const framework::AttributeMap& attrs) { attrs_ = attrs; } + void SetOpType(const std::string& op_type) { op_type_ = op_type; } + void Run(std::function infer_shape) { + imperative::DygraphInferShapeContext ctx( + &ins_, &outs_, &attrs_, op_type_); + infer_shape(&ctx); + for (const auto& pair : expected_dims_) { + auto out = outs_[pair.first][0]; + ASSERT_EQ(pair.second, + out->MutableVar()->GetMutable()->dims()); + } + } + + private: + imperative::NameVarBaseMap ins_; + imperative::NameVarBaseMap outs_; + framework::AttributeMap attrs_; + std::string op_type_; + std::map expected_dims_; +}; +} // namespace details + +TEST(test_UnaryOpUnchangedInferShape, test_shape) { + details::DygraphInferShapeTest test; + test.AddInput("X", {2, 10}); + test.AddOutput("Out", {2, 10}); + test.SetOpType("relu"); + test.Run(UnaryOpUnchangedInferShape); +} + +TEST(test_BinaryOpBroadcastInferShape, test_same_shape) { + details::DygraphInferShapeTest test; + test.AddInput("X", {2, 3, 4, 5}); + test.AddInput("Y", {2, 3, 4, 5}); + test.AddOutput("Out", {2, 3, 4, 5}); + test.SetOpType("elementwise_add"); + test.Run(BinaryOpBroadcastInferShape); +} + +TEST(test_BinaryOpBroadcastInferShape, test_broadcast1) { + details::DygraphInferShapeTest test; + test.AddInput("X", {2, 3, 4, 5}); + test.AddInput("Y", {4, 5}); + test.AddOutput("Out", {2, 3, 4, 5}); + test.AddAttrs({ + {"axis", -1}, + }); + test.SetOpType("elementwise_add"); + test.Run(BinaryOpBroadcastInferShape); +} + +TEST(test_BinaryOpBroadcastInferShape, test_broadcast2) { + details::DygraphInferShapeTest test; + test.AddInput("X", {2, 10, 5, 1}); + test.AddInput("Y", {10, 1, 1}); + test.AddOutput("Out", {2, 10, 5, 1}); + test.AddAttrs({ + {"axis", -1}, + }); + test.SetOpType("elementwise_add"); + test.Run(BinaryOpBroadcastInferShape); +} + +TEST(test_BinaryOpBroadcastInferShape, test_broadcast3) { + details::DygraphInferShapeTest test; + test.AddInput("X", {10, 1, 1}); + test.AddInput("Y", {2, 10, 5, 5}); + test.AddOutput("Out", {2, 10, 5, 5}); + test.AddAttrs({ + {"axis", -1}, + }); + test.SetOpType("elementwise_add"); + test.Run(BinaryOpBroadcastInferShape); +} + +TEST(test_UnaryOpUnchangedInferShapeCheckAxis, test_shape) { + details::DygraphInferShapeTest test; + test.AddInput("X", {2, 10}); + test.AddOutput("Out", {2, 10}); + test.AddAttrs({ + {"axis", -1}, + }); + test.SetOpType("softmax"); + test.Run(UnaryOpUnchangedInferShapeCheckAxis); +} + +TEST(test_UnaryOpUnchangedInferShapeCheckAxis, test_axis_exception) { + details::DygraphInferShapeTest test; + test.AddInput("X", {2, 10}); + test.AddOutput("Out", {2, 10}); + test.AddAttrs({ + {"axis", 2}, + }); + test.SetOpType("softmax"); + ASSERT_ANY_THROW(test.Run(UnaryOpUnchangedInferShapeCheckAxis)); +} + +} // namespace operators +} // namespace paddle