未验证 提交 ffe52b44 编写于 作者: L Leo Chen 提交者: GitHub

[OpDevOptimize] Add common infershape functions (#26096)

* add unchaged infershape function

* add broadcast infershape function

* fix bug

* rename infershape functions

* add UnaryOpUnchangedInferShapeCheckAxis

* add error message

* add test for common infer shape functions

* dont update existed ops

* dont update op_desc.h

* add more test

* add error check, refine error message
上级 2d95280e
......@@ -13,7 +13,7 @@ function(op_library TARGET)
set(CUDNN_FILE)
set(mkldnn_cc_srcs)
set(MKLDNN_FILE)
set(op_common_deps operator op_registry math_function layer)
set(op_common_deps operator op_registry math_function layer common_infer_shape_functions)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
......
......@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_desc.h"
#include <algorithm>
#include <functional>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <utility>
#include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_call_stack.h"
......@@ -51,6 +53,29 @@ class CompileTimeInferShapeContext : public InferShapeContext {
std::vector<std::string> Outputs(const std::string &name) const override;
std::string GetInputNameByIdx(size_t idx) const override {
auto &op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(idx, op_proto->inputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of inputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(), idx, op_proto->inputs().size()));
return op_proto->inputs()[idx].name();
}
std::string GetOutputNameByIdx(size_t idx) const override {
auto &op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(
idx, op_proto->outputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of outputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(), idx, op_proto->outputs().size()));
return op_proto->outputs()[idx].name();
}
void ShareDim(const std::string &in, const std::string &out, size_t i = 0,
size_t j = 0) override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
......
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include <gflags/gflags.h>
#include <glog/logging.h>
......@@ -20,13 +22,13 @@ limitations under the License. */
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_call_stack.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/unused_var_check.h"
......@@ -604,6 +606,29 @@ class RuntimeInferShapeContext : public InferShapeContext {
return op_.Outputs(name);
}
std::string GetInputNameByIdx(size_t idx) const override {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(idx, op_proto->inputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of inputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(), idx, op_proto->inputs().size()));
return op_proto->inputs()[idx].name();
}
std::string GetOutputNameByIdx(size_t idx) const override {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(
idx, op_proto->outputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of outputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(), idx, op_proto->outputs().size()));
return op_proto->outputs()[idx].name();
}
void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) override {
auto in_it = ctx_.inputs.find(in);
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/framework.pb.h"
......@@ -52,7 +53,8 @@ class InferShapeContext {
const std::vector<DDim> &dims) = 0;
virtual void SetReaderDims(const std::string &name,
const std::vector<DDim> &dims);
virtual std::string GetInputNameByIdx(size_t idx) const = 0;
virtual std::string GetOutputNameByIdx(size_t idx) const = 0;
virtual AttrReader Attrs() const = 0;
virtual std::vector<std::string> Inputs(const std::string &name) const = 0;
virtual std::vector<std::string> Outputs(const std::string &name) const = 0;
......
......@@ -16,7 +16,9 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/type_defs.h"
......@@ -32,8 +34,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
public:
DygraphInferShapeContext(const NameVarMap<VarType>* in,
const NameVarMap<VarType>* out,
const framework::AttributeMap* attr)
: var_base_map_in_(in), var_base_map_out_(out), attrs_(attr) {}
const framework::AttributeMap* attr,
const std::string op_type)
: var_base_map_in_(in),
var_base_map_out_(out),
attrs_(attr),
op_type_(op_type) {}
bool HasInput(const std::string& name) const override {
// has only one input
......@@ -135,6 +141,28 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
return vec_res;
}
std::string GetInputNameByIdx(size_t idx) const override {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_type_).proto_;
PADDLE_ENFORCE_LT(idx, op_proto->inputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of inputs of "
"operator %s, but got index is %d and size is %d",
op_type_, idx, op_proto->inputs().size()));
return op_proto->inputs()[idx].name();
}
std::string GetOutputNameByIdx(size_t idx) const override {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_type_).proto_;
PADDLE_ENFORCE_LT(
idx, op_proto->outputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of outputs of "
"operator %s, but got index is %d and size is %d",
op_type_, idx, op_proto->outputs().size()));
return op_proto->outputs()[idx].name();
}
void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) override {
......@@ -367,6 +395,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const NameVarMap<VarType>* var_base_map_in_;
const NameVarMap<VarType>* var_base_map_out_;
const framework::AttributeMap* attrs_;
const std::string op_type_;
};
} // namespace imperative
......
......@@ -13,7 +13,9 @@
// limitations under the License.
#include "paddle/fluid/imperative/prepared_operator.h"
#include <sstream>
#include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/infer_shape_context.h"
#include "paddle/fluid/imperative/infer_var_type_context.h"
......@@ -137,7 +139,8 @@ static void PreparedOpRunImpl(
// TODO(zjl): remove scope in dygraph
framework::Scope scope;
DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs);
DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
op.Type());
static_cast<const framework::OperatorWithKernel&>(op).InferShape(
&infer_shape_ctx);
......
......@@ -17,9 +17,11 @@
//
#include <paddle/fluid/framework/op_registry.h>
#include <memory>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/infer_shape_context.h"
......@@ -384,7 +386,7 @@ TEST(test_layer, test_dygraph_infershape_context) {
concat_att_map["axis"] = 1;
DygraphInferShapeContext<imperative::VarBase> infer_shape_ctx(
&ins, &outs, &concat_att_map);
&ins, &outs, &concat_att_map, "dummy");
bool have_x = infer_shape_ctx.HasOutputs("Out");
ASSERT_EQ(have_x, true);
......
......@@ -86,12 +86,14 @@ if (WITH_DGC)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dgc)
endif()
cc_library(common_infer_shape_functions SRCS common_infer_shape_functions.cc DEPS operator)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor device_memory_aligment)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions)
if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor)
endif()
......@@ -111,6 +113,7 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} tensor_formatter)
set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS})
set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies")
cc_test(test_common_infer_shape_functions SRCS test_common_infer_shape_functions.cc DEPS common_infer_shape_functions ${COMMON_OP_DEPS} activation_op elementwise_add_op softmax_op softmax)
cc_test(assign_op_test SRCS assign_op_test.cc DEPS assign_op)
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor math_function)
......
......@@ -13,11 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/activation_op.h"
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_CUDA
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include <algorithm>
#include <vector>
// This file almostly contains all the infershape functions that are used in
// operators.
namespace paddle {
namespace operators {
namespace details {
inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
const framework::DDim &y_dims,
int *x_dims_array, int *y_dims_array,
int *out_dims_array, const int max_dim,
const int axis) {
PADDLE_ENFORCE_GE(
axis, 0,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis, max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim, axis));
if (x_dims.size() > y_dims.size()) {
std::fill(y_dims_array, y_dims_array + axis, 1);
if (axis + y_dims.size() < max_dim) {
std::fill(y_dims_array + axis + y_dims.size(), y_dims_array + max_dim, 1);
}
std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array);
std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array + axis);
} else {
std::fill(x_dims_array, x_dims_array + axis, 1);
if (axis + x_dims.size() < max_dim) {
std::fill(x_dims_array + axis + x_dims.size(), x_dims_array + max_dim, 1);
}
std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array + axis);
std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array);
}
for (int i = 0; i < max_dim; i++) {
PADDLE_ENFORCE_EQ(
x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 ||
y_dims_array[i] <= 1,
true, platform::errors::InvalidArgument(
"Broadcast dimension mismatch. Operands could "
"not be broadcast together with the shape of X = [%s] and "
"the shape of Y = [%s]. Received [%d] in X is not equal to "
"[%d] in Y at i:%d.",
x_dims, y_dims, x_dims_array[i], y_dims_array[i], i));
if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) ||
(x_dims_array[i] == 1 && y_dims_array[i] == 1)) {
out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]);
} else {
out_dims_array[i] = -1;
}
}
}
} // namespace details
// shape input(0) -> output(0) without change.
void UnaryOpUnchangedInferShape(framework::InferShapeContext *ctx) {
auto x_name = ctx->GetInputNameByIdx(0);
auto out_name = ctx->GetOutputNameByIdx(0);
ctx->ShareDim(x_name, /*->*/ out_name);
ctx->ShareLoD(x_name, /*->*/ out_name);
}
// shape input(0) -> output(0) without change, check if axis in range [-Rank(x),
// Rank(x)-1]
void UnaryOpUnchangedInferShapeCheckAxis(framework::InferShapeContext *ctx) {
auto x_name = ctx->GetInputNameByIdx(0);
auto out_name = ctx->GetOutputNameByIdx(0);
auto x_dim = ctx->GetInputDim(x_name);
auto x_rank = x_dim.size();
auto axis = ctx->Attrs().Get<int>("axis");
PADDLE_ENFORCE_GE(
axis, -x_rank,
platform::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X). But received axis: %d, R: %d.",
axis, x_rank));
PADDLE_ENFORCE_LT(
axis, x_rank,
platform::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X). But received axis: %d, R: %d.",
axis, x_rank));
ctx->ShareDim(x_name, /*->*/ out_name);
ctx->ShareLoD(x_name, /*->*/ out_name);
}
// broadcast input(0) and input(1) -> output(0)
void BinaryOpBroadcastInferShape(framework::InferShapeContext *ctx) {
auto x_name = ctx->GetInputNameByIdx(0);
auto y_name = ctx->GetInputNameByIdx(1);
auto out_name = ctx->GetOutputNameByIdx(0);
auto x_dims = ctx->GetInputDim(x_name);
auto y_dims = ctx->GetInputDim(y_name);
PADDLE_ENFORCE_EQ(
ctx->GetInputsVarType(y_name).front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The var type of input %s should be LoDTensor, but got %s.",
ctx->Inputs(y_name).front(), ctx->GetInputsVarType(y_name).front()));
if (ctx->GetInputsVarType(x_name).front() ==
framework::proto::VarType::SELECTED_ROWS) {
PADDLE_ENFORCE_EQ(y_dims.size(), 1u,
platform::errors::InvalidArgument(
"For binary broadcastable operator, if X is "
"Sparse(VarType.SELECTED_ROWS"
"), Y must be scalar, and the size of Y should be 1. "
"But reveived the size of Y = %s.",
y_dims.size()));
PADDLE_ENFORCE_EQ(
y_dims[0], 1,
platform::errors::InvalidArgument(
"For binary broadcastable operator, if X is "
"Sparse(VarType.SELECTED_ROWS"
"), Y must be scalar, the first dimension of Y should be 1. "
"But reveived the first dimension of Y = %s.",
y_dims[0]));
} else if (ctx->GetInputsVarType(x_name).front() !=
framework::proto::VarType::LOD_TENSOR) {
PADDLE_THROW(platform::errors::InvalidArgument(
"For binary broadcastable operator, the var type of input X should "
"be LOD_TENSOR, but got %s",
ctx->GetInputsVarType(x_name).front()));
}
if (x_dims == y_dims) {
ctx->ShareDim(x_name, /*->*/ out_name);
ctx->ShareLoD(x_name, /*->*/ out_name);
} else {
int max_dim = std::max(x_dims.size(), y_dims.size());
int axis = ctx->Attrs().Get<int>("axis");
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
details::GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(),
max_dim, axis);
ctx->SetOutputDim(out_name, framework::make_ddim(out_dims_array));
ctx->ShareLoD(x_name, /*->*/ out_name);
}
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
// This file almostly contains all the infershape functions that are used in
// operators.
namespace paddle {
namespace operators {
// shape input(0) -> output(0) without change.
void UnaryOpUnchangedInferShape(framework::InferShapeContext* ctx);
// shape input(0) -> output(0) without change, check if axis in range [-Rank(x),
// Rank(x)-1]
void UnaryOpUnchangedInferShapeCheckAxis(framework::InferShapeContext* ctx);
// broadcast input(0) and input(1) -> output(0)
void BinaryOpBroadcastInferShape(framework::InferShapeContext* ctx);
} // namespace operators
} // namespace paddle
......@@ -19,9 +19,11 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#ifdef PADDLE_WITH_MKLDNN
......
......@@ -13,10 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/selu_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/common_infer_shape_functions.h"
namespace paddle {
namespace operators {
......@@ -28,11 +31,7 @@ class SeluOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "selu");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "selu");
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
return UnaryOpUnchangedInferShape(ctx);
}
protected:
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/imperative/infer_shape_context.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
USE_OP(relu);
USE_OP(elementwise_add);
USE_OP(softmax);
namespace paddle {
namespace operators {
namespace details {
class DygraphInferShapeTest {
public:
void AddInput(const std::string& name, const framework::DDim& dim) {
std::shared_ptr<imperative::VarBase> vin(
new imperative::VarBase(false, name));
vin->MutableVar()->GetMutable<framework::LoDTensor>()->Resize(dim);
ins_[name] = {vin};
}
void AddOutput(const std::string& name, const framework::DDim& expected_dim) {
std::shared_ptr<imperative::VarBase> vout(
new imperative::VarBase(false, name));
vout->MutableVar()
->GetMutable<framework::LoDTensor>(); // InitializeVariable
outs_[name] = {vout};
expected_dims_[name] = expected_dim;
}
void AddAttrs(const framework::AttributeMap& attrs) { attrs_ = attrs; }
void SetOpType(const std::string& op_type) { op_type_ = op_type; }
void Run(std::function<void(framework::InferShapeContext* ctx)> infer_shape) {
imperative::DygraphInferShapeContext<imperative::VarBase> ctx(
&ins_, &outs_, &attrs_, op_type_);
infer_shape(&ctx);
for (const auto& pair : expected_dims_) {
auto out = outs_[pair.first][0];
ASSERT_EQ(pair.second,
out->MutableVar()->GetMutable<framework::LoDTensor>()->dims());
}
}
private:
imperative::NameVarBaseMap ins_;
imperative::NameVarBaseMap outs_;
framework::AttributeMap attrs_;
std::string op_type_;
std::map<std::string, framework::DDim> expected_dims_;
};
} // namespace details
TEST(test_UnaryOpUnchangedInferShape, test_shape) {
details::DygraphInferShapeTest test;
test.AddInput("X", {2, 10});
test.AddOutput("Out", {2, 10});
test.SetOpType("relu");
test.Run(UnaryOpUnchangedInferShape);
}
TEST(test_BinaryOpBroadcastInferShape, test_same_shape) {
details::DygraphInferShapeTest test;
test.AddInput("X", {2, 3, 4, 5});
test.AddInput("Y", {2, 3, 4, 5});
test.AddOutput("Out", {2, 3, 4, 5});
test.SetOpType("elementwise_add");
test.Run(BinaryOpBroadcastInferShape);
}
TEST(test_BinaryOpBroadcastInferShape, test_broadcast1) {
details::DygraphInferShapeTest test;
test.AddInput("X", {2, 3, 4, 5});
test.AddInput("Y", {4, 5});
test.AddOutput("Out", {2, 3, 4, 5});
test.AddAttrs({
{"axis", -1},
});
test.SetOpType("elementwise_add");
test.Run(BinaryOpBroadcastInferShape);
}
TEST(test_BinaryOpBroadcastInferShape, test_broadcast2) {
details::DygraphInferShapeTest test;
test.AddInput("X", {2, 10, 5, 1});
test.AddInput("Y", {10, 1, 1});
test.AddOutput("Out", {2, 10, 5, 1});
test.AddAttrs({
{"axis", -1},
});
test.SetOpType("elementwise_add");
test.Run(BinaryOpBroadcastInferShape);
}
TEST(test_BinaryOpBroadcastInferShape, test_broadcast3) {
details::DygraphInferShapeTest test;
test.AddInput("X", {10, 1, 1});
test.AddInput("Y", {2, 10, 5, 5});
test.AddOutput("Out", {2, 10, 5, 5});
test.AddAttrs({
{"axis", -1},
});
test.SetOpType("elementwise_add");
test.Run(BinaryOpBroadcastInferShape);
}
TEST(test_UnaryOpUnchangedInferShapeCheckAxis, test_shape) {
details::DygraphInferShapeTest test;
test.AddInput("X", {2, 10});
test.AddOutput("Out", {2, 10});
test.AddAttrs({
{"axis", -1},
});
test.SetOpType("softmax");
test.Run(UnaryOpUnchangedInferShapeCheckAxis);
}
TEST(test_UnaryOpUnchangedInferShapeCheckAxis, test_axis_exception) {
details::DygraphInferShapeTest test;
test.AddInput("X", {2, 10});
test.AddOutput("Out", {2, 10});
test.AddAttrs({
{"axis", 2},
});
test.SetOpType("softmax");
ASSERT_ANY_THROW(test.Run(UnaryOpUnchangedInferShapeCheckAxis));
}
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册