未验证 提交 cec234b1 编写于 作者: S silingtong123 提交者: GitHub

test=develop, error message of tree_conv OP enhancement (#23574)

上级 7277df47
...@@ -60,40 +60,78 @@ class TreeConvOp : public framework::OperatorWithKernel { ...@@ -60,40 +60,78 @@ class TreeConvOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out")); OP_INOUT_CHECK(ctx->HasInput("NodesVector"), "Input", "NodesVector",
"TreeConv");
OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "TreeConv");
OP_INOUT_CHECK(ctx->HasInput("EdgeSet"), "Input", "EdgeSet", "TreeConv");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "TreeConv");
auto edge_dims = ctx->GetInputDim("EdgeSet"); auto edge_dims = ctx->GetInputDim("EdgeSet");
auto vector_dims = ctx->GetInputDim("NodesVector"); auto vector_dims = ctx->GetInputDim("NodesVector");
auto filter_dims = ctx->GetInputDim("Filter"); auto filter_dims = ctx->GetInputDim("Filter");
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(edge_dims[2], 2, "Input(EdgeSet) dim[2] should be 2"); PADDLE_ENFORCE_EQ(edge_dims[2], 2,
platform::errors::InvalidArgument(
"Input(EdgeSet) dim[2] should be 2. "
"But received Input(EdgeSet) dim[2] is %d.",
edge_dims[2]));
} else { } else {
if (edge_dims[2] != -1) { if (edge_dims[2] != -1) {
PADDLE_ENFORCE_EQ(edge_dims[2], 2, "Input(EdgeSet) dim[2] should be 2"); PADDLE_ENFORCE_EQ(edge_dims[2], 2,
platform::errors::InvalidArgument(
"Input(EdgeSet) dim[2] should be 2. "
"But received Input(EdgeSet) dim[2] is %d.",
edge_dims[2]));
} }
} }
PADDLE_ENFORCE_EQ(edge_dims.size(), 3, PADDLE_ENFORCE_EQ(edge_dims.size(), 3,
"The dimension of EdgeSet Tensor should be 3"); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(vector_dims.size(), 3, "The dimension of EdgeSet Tensor should be 3. "
"The dimension of NodesVector Tensor should be 3"); "But received the dimension of EdgeSet Tensor is %d.",
edge_dims.size()));
PADDLE_ENFORCE_EQ(
vector_dims.size(), 3,
platform::errors::InvalidArgument(
"The dimension of NodesVector Tensor should be 3. "
"But received the dimension of NodesVector Tensor is %d.",
vector_dims.size()));
PADDLE_ENFORCE_EQ(filter_dims.size(), 4, PADDLE_ENFORCE_EQ(filter_dims.size(), 4,
"The dimension of Filter Tensor should be 4"); platform::errors::InvalidArgument(
"The dimension of Filter Tensor should be 4. "
"But received the dimension of Filter Tensor is %d.",
filter_dims.size()));
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(filter_dims[1], 3, "Input(Filter) dim[1] should be 3"); PADDLE_ENFORCE_EQ(filter_dims[1], 3,
platform::errors::InvalidArgument(
"Input(Filter) dim[1] should be 3. "
"But received Input(Filter) dim[1] is %d.",
filter_dims[1]));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
filter_dims[0], vector_dims[2], filter_dims[0], vector_dims[2],
"Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]"); platform::errors::InvalidArgument(
"Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]. "
"But received Input(Filter) dim[0] = %d, Input(NodesVector) "
"dim[2] = %d.",
filter_dims[0], vector_dims[2]));
} else { } else {
if (filter_dims[1] != -1) { if (filter_dims[1] != -1) {
PADDLE_ENFORCE_EQ(filter_dims[1], 3, PADDLE_ENFORCE_EQ(filter_dims[1], 3,
"Input(Filter) dim[1] should be 3"); platform::errors::InvalidArgument(
"Input(Filter) dim[1] should be 3. "
"But received Input(Filter) dim[1] is %d.",
filter_dims[1]));
} }
if (filter_dims[0] != -1 && vector_dims[2] != -1) { if (filter_dims[0] != -1 && vector_dims[2] != -1) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
filter_dims[0], vector_dims[2], filter_dims[0], vector_dims[2],
"Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]"); platform::errors::InvalidArgument(
"Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]. "
"But received Input(Filter) dim[0] = %d, Input(NodesVector) "
"dim[2] = %d.",
filter_dims[0], vector_dims[2]));
} }
} }
auto output_dims = framework::make_ddim( auto output_dims = framework::make_ddim(
...@@ -137,10 +175,21 @@ class TreeConvGradOp : public framework::OperatorWithKernel { ...@@ -137,10 +175,21 @@ class TreeConvGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "grad_TreeConv");
OP_INOUT_CHECK(ctx->HasInput("EdgeSet"), "Input", "EdgeSet",
"grad_TreeConv");
OP_INOUT_CHECK(ctx->HasInput("NodesVector"), "Input", "NodesVector",
"grad_TreeConv");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "grad_TreeConv");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("NodesVector")),
"Output", framework::GradVarName("NodesVector"),
"grad_TreeConv");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Filter")), "Output",
framework::GradVarName("Filter"), "grad_TreeConv");
auto vectors_dims = ctx->GetInputDim("NodesVector"); auto vectors_dims = ctx->GetInputDim("NodesVector");
auto filter_dims = ctx->GetInputDim("Filter"); auto filter_dims = ctx->GetInputDim("Filter");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"the gradient of output(Out) must not be null");
if (ctx->HasOutput(framework::GradVarName("Filter"))) { if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
} }
......
...@@ -419,6 +419,9 @@ def tree_conv(nodes_vector, ...@@ -419,6 +419,9 @@ def tree_conv(nodes_vector,
# also output tensor could be pooling(the pooling in paper called global pooling) # also output tensor could be pooling(the pooling in paper called global pooling)
pooled = fluid.layers.reduce_max(out_vector, dim=2) # global pooling pooled = fluid.layers.reduce_max(out_vector, dim=2) # global pooling
""" """
check_type(nodes_vector, 'nodes_vector', (Variable), 'tree_conv')
check_type(edge_set, 'edge_set', (Variable), 'tree_conv')
helper = LayerHelper("tree_conv", **locals()) helper = LayerHelper("tree_conv", **locals())
dtype = helper.input_dtype('nodes_vector') dtype = helper.input_dtype('nodes_vector')
feature_size = nodes_vector.shape[2] feature_size = nodes_vector.shape[2]
......
...@@ -2949,6 +2949,8 @@ class TreeConv(layers.Layer): ...@@ -2949,6 +2949,8 @@ class TreeConv(layers.Layer):
is_bias=False) is_bias=False)
def forward(self, nodes_vector, edge_set): def forward(self, nodes_vector, edge_set):
check_type(nodes_vector, 'nodes_vector', (Variable), 'TreeConv')
check_type(edge_set, 'edge_set', (Variable), 'TreeConv')
if self._name: if self._name:
out = self.create_variable( out = self.create_variable(
name=self._name, dtype=self._dtype, persistable=False) name=self._name, dtype=self._dtype, persistable=False)
......
...@@ -13,8 +13,10 @@ ...@@ -13,8 +13,10 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from paddle.fluid.framework import program_guard, Program
from op_test import OpTest from op_test import OpTest
import unittest
import paddle.fluid as fluid
def collect_node_patch(og, max_depth): def collect_node_patch(og, max_depth):
...@@ -118,3 +120,45 @@ class TestTreeConvOp(OpTest): ...@@ -118,3 +120,45 @@ class TestTreeConvOp(OpTest):
], ],
axis=0) axis=0)
return vec return vec
class TestTreeConv_OpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
nodes_vector_1 = np.random.random((10, 5)).astype("float32")
edge_set_1 = fluid.layers.data(
name='edge_set_1', shape=[10, 2], dtype='float32')
# the nodes_vector of tree_conv must be Variable.
self.assertRaises(TypeError, fluid.contrib.layers.tree_conv,
nodes_vector_1, edge_set_1, 3)
nodes_vector_2 = fluid.layers.data(
name='vectors2', shape=[10, 5], dtype='float32')
edge_set_2 = np.random.random((10, 2)).astype("float32")
# the edge_set of tree_conv must be Variable.
self.assertRaises(TypeError, fluid.contrib.layers.tree_conv,
nodes_vector_2, edge_set_2, 3)
class TestDygraphTreeConv_OpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
TreeConv = fluid.dygraph.nn.TreeConv(
feature_size=5, output_size=6, num_filters=1, max_depth=2)
nodes_vector_1 = np.random.random((10, 5)).astype("float32")
edge_set_1 = fluid.layers.data(
name='edge_set_1', shape=[10, 2], dtype='float32')
# the nodes_vector of TreeConv must be Variable.
self.assertRaises(TypeError, TreeConv, nodes_vector_1, edge_set_1,
3)
nodes_vector_2 = fluid.layers.data(
name='vectors2', shape=[10, 5], dtype='float32')
edge_set_2 = np.random.random((10, 2)).astype("float32")
# the edge_set of TreeConv must be Variable.
self.assertRaises(TypeError, TreeConv, nodes_vector_2, edge_set_2,
3)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册