Commit 6ac32d09 authored by Jiawei Wang, committed by Dong Daxiang

Instag Implementation (#18394)

* instag lod tensor impl

* First PR for instag

* First PR for instag

* Before adding Selection Rows.

* Change name from instag to filter_instag, and upgrade the impl of filter_instag

* Change name from instag to filter_instag, and upgrade the impl of filter_instag

* Fix yapf error in gradient_checker.py to pass Travis-CI

* Fix Filter Instag Grad test=develop

* Fix Filter Instag Grad test=develop

* 1) Fix API.spec, add filter_instag Op. 2) Add Vector Support for CUDA. test=develop

* Impl Loss_weight and empty output handler

* change Loss Weight datatype to Float32, and add Loss Weight as 2nd output

* 1) Support Tensor Input(without LOD) 2) Add Unit test

* Filter By Instag Final test=develop

* Update API.spec for filter_by_instag test=develop

* Update API.spec for filter_by_instag 2 test=develop

* Add Filter By Instag Coverage

* code format of test_layers.py

* code format test_layers.py test=develop

* Make API args more readable test=develop

* Make API args more readable and pass code format test=develop

* Filter By Instag Op, Rename Map to Index Map test=develop

* Filter By Instag Op, code format err in filter_by_instag_op.cc  test=develop

* Filter by instag op: code format of cpp files test=develop

* Filter by instag Op: Api spec modification test=develop

* Filter by instag Op: Api spec doc id modification test=develop

* Filter by instag Op: Api spec and doc preview  test=develop test=document_preview

* Filter By Instag Op, fix doc error test=document_preview test=develop

* Filter By Instag Op, fix doc error and Api spec test=document_preview test=develop

* Filter By Instag Op, fix Api spec test=document_preview test=develop

* Filter By Instag Op, fix Paddle Enforce deprecated warning test=document_preview test=develop

* Filter By Instag Op, fix Paddle Enforce deprecated and code format warning test=document_preview test=develop
Parent 2e76e755
......@@ -268,6 +268,7 @@ paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defau
paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, None)), ('document', '4d83ba6b971cfd590493b0925b3e081e'))
paddle.fluid.layers.unfold (ArgSpec(args=['x', 'kernel_sizes', 'strides', 'paddings', 'dilations', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None)), ('document', '3f884662ad443d9ecc2b3734b4f61ad6'))
paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'trans', 'no_trans', 'spatial_scale', 'group_size', 'pooled_height', 'pooled_width', 'part_size', 'sample_per_part', 'trans_std', 'position_sensitive', 'name'], varargs=None, keywords=None, defaults=(False, 1.0, [1, 1], 1, 1, None, 1, 0.1, False, None)), ('document', '99c03e3f249e36854f87dedaa17c8f35'))
paddle.fluid.layers.filter_by_instag (ArgSpec(args=['ins', 'ins_tag', 'filter_tag', 'is_lod'], varargs=None, keywords=None, defaults=None), ('document', '7703a2088af8de4128b143ff1164ca4a'))
paddle.fluid.layers.var_conv_2d (ArgSpec(args=['input', 'row', 'col', 'input_channel', 'output_channel', 'filter_size', 'stride', 'param_attr', 'act', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, None, 'float32', None)), ('document', '7a8b8ade5512c95f9ea30261d33ded6c'))
paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '5786fdbba6753ecd6cbce5e6b0889924'))
paddle.fluid.layers.hard_swish (ArgSpec(args=['x', 'threshold', 'scale', 'offset', 'name'], varargs=None, keywords=None, defaults=(6.0, 6.0, 3.0, None)), ('document', '6a5152a7015c62cb8278fc24cb456459'))
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/filter_by_instag_op.h"
#include <memory>
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/var_type_inference.h"
namespace paddle {
namespace operators {
class FilterByInstagOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Ins"), true,
"Input(Ins) should be not null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Ins_tag"), true,
"Input(Ins_tag) should be not null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Filter_tag"), true,
"Input(Filter_tag) should be not null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) should be not null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("LossWeight"), true,
"Output(LossWeight) shoudl not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("IndexMap"), true,
"Output(IndexMap) should be not null.");
auto x1_dims = ctx->GetInputDim("Ins"); // batch_size * vec
ctx->SetOutputDim("Out", framework::make_ddim({-1, x1_dims[1]}));
ctx->SetOutputDim("LossWeight", framework::make_ddim({-1, 1}));
ctx->SetOutputDim("IndexMap", framework::make_ddim({-1, 2}));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Ins"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class FilterByInstagOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ins", "(LoDTensor) embeded tensor");
AddInput("Ins_tag", "(LoDTensor) ins tag list");
AddInput("Filter_tag", "(1D Tensor) filter tag list");
AddAttr<bool>("is_lod", "is Ins with LoD info or not, default True");
AddOutput("Out", "(LoDTensor) embeded tensor filtered by instag");
AddOutput("LossWeight", "(Tensor) loss weight.");
AddOutput("IndexMap", "(LoDTensor) mapping from Out rows to X1 rows");
AddComment(R"DOC(
Filter By Instag Op
This operator is used to filter embeded ins.
There are 3 inputs. First is embeded ins, Second is tags for ins,
Third is tags to filter.
There are 3 outputs. First is filtered embeded ins, Second is Loss Weight,
Third is the IndexMap from Out line number to X1 line number.
)DOC");
}
};
class FilterByInstagOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("IndexMap"), true,
"Input(IndexMap) should be not null");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Grad Input(Out) should be not null");
PADDLE_ENFORCE_EQ(ctx->HasInput("Ins"), true,
"Input(Ins) should be not null");
PADDLE_ENFORCE_EQ(ctx->HasInput("LossWeight"), true,
"Input(LossWeight) should be not null");
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Ins")), true,
"Grad Output(Ins) should be not null");
auto grad_out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x1_dims = ctx->GetInputDim("Ins");
ctx->SetOutputDim(framework::GradVarName("Ins"),
framework::make_ddim({x1_dims[0], grad_out_dims[1]}));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(
ctx.InputVar(framework::GradVarName("Out")));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class FilterByInstagGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("filter_by_instag_grad");
op->SetInput("IndexMap", Output("IndexMap"));
op->SetInput("Ins", Input("Ins"));
op->SetAttrMap(Attrs());
op->SetInput("LossWeight", Output("LossWeight"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("Ins"), InputGrad("Ins"));
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(filter_by_instag, ops::FilterByInstagOp,
ops::FilterByInstagOpMaker,
ops::FilterByInstagGradOpDescMaker);
REGISTER_OPERATOR(filter_by_instag_grad, ops::FilterByInstagOpGrad);
REGISTER_OP_CPU_KERNEL(filter_by_instag, ops::FilterByInstagKernel<float>,
ops::FilterByInstagKernel<double>,
ops::FilterByInstagKernel<int32_t>,
ops::FilterByInstagKernel<int64_t>);
REGISTER_OP_CPU_KERNEL(filter_by_instag_grad,
ops::FilterByInstagGradKernel<float>,
ops::FilterByInstagGradKernel<double>,
ops::FilterByInstagGradKernel<int32_t>,
ops::FilterByInstagGradKernel<int64_t>);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstring>
#include <random>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/assert.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using SelectedRows = framework::SelectedRows;
using LoDTensor = framework::LoDTensor;
#if defined(PADDLE_WITH_CUDA)
template <typename T>
using Vector = framework::Vector<T>;
#else
template <typename T>
using Vector = framework::CPUVector<T>;
#endif
template <typename T>
class FilterByInstagKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
// X1 is global FC output
// Dim [batch size, embedding size]
auto* x1 = context.Input<LoDTensor>("Ins");
bool is_x1_lod = context.Attr<bool>("is_lod");
// X2 is ins tag list
// LoD [[0, Sum(ins1), Sum(ins1, ins2), ... ]]
auto* x2 = context.Input<LoDTensor>("Ins_tag");
// X3 is local fc tag list
// LoD [[0, Sum(fc1), Sum(fc1, fc2) ...]]
auto* x3 = context.Input<Tensor>("Filter_tag");
std::unordered_set<int64_t> filter_tag;
auto* x3_data = x3->data<int64_t>();
size_t len = x3->dims()[0];
for (size_t i = 0; i < len; i++) {
filter_tag.insert(x3_data[i]);
}
// expected auto = const int64_t
auto* x2_data = x2->data<int64_t>();
// e.g get [0, 1, 2, 3, ...]
auto x2_lods = x2->lod()[0];
Vector<size_t> x1_lods(1, 0);
if (!is_x1_lod) {
for (size_t i = 0; i < x1->dims()[0]; i++) {
x1_lods.push_back(i + 1);
}
} else {
x1_lods = context.Input<LoDTensor>("Ins")->lod()[0];
}
std::unordered_map<int64_t, int64_t> mmap_aux;
std::vector<size_t> ins_after_filter;
Vector<size_t> out_lods(1, 0);
for (size_t i = 0; i < x2_lods.size() - 1; i++) {
for (size_t j = x2_lods[i]; j < x2_lods[i + 1]; j++) {
if (filter_tag.find(x2_data[j]) != filter_tag.end()) {
// the i-th ins carries at least one wanted tag: record its index
// (indexing by i, not by the x2 offset, so multi-tag ins also work)
ins_after_filter.push_back(i);
size_t batch_len = x1_lods[i + 1] - x1_lods[i];
mmap_aux[out_lods.back()] = x1_lods[i];
out_lods.push_back(out_lods.back() + batch_len);
break;
}
}
}
// set output value:
// zero-initialize the output, then copy the whole line for every ins
// that passed the filter; ins that were filtered out contribute nothing
// Dim [local fc count, batch size, embedding size]
LoDTensor* out = context.Output<LoDTensor>("Out");
LoDTensor* map = context.Output<LoDTensor>("IndexMap");
LoDTensor* loss_weight = context.Output<LoDTensor>("LossWeight");
// expected auto = const T
auto* x1_data = x1->data<T>();
// expected auto = T
size_t x1_embed_size = x1->dims()[1];
if (ins_after_filter.size() > 0) {
out->Resize(framework::make_ddim(
{(int64_t)out_lods.back(), (int64_t)x1_embed_size}));
map->Resize(framework::make_ddim({(int64_t)ins_after_filter.size(), 3}));
loss_weight->Resize(
framework::make_ddim({(int64_t)ins_after_filter.size(), 1}));
} else {
out->Resize(framework::make_ddim({1, (int64_t)x1_embed_size}));
map->Resize(framework::make_ddim({1, 3}));
loss_weight->Resize(framework::make_ddim({1, 1}));
}
auto* out_data = out->mutable_data<T>(context.GetPlace());
auto* map_data = map->mutable_data<int64_t>(context.GetPlace());
auto* loss_weight_data =
loss_weight->mutable_data<float>(context.GetPlace());
if (ins_after_filter.size() > 0) {
Vector<size_t> map_lods;
for (size_t i = 0; i < ins_after_filter.size(); i++) {
map_data[i * 3] = (int64_t)out_lods[i];
map_data[i * 3 + 1] = mmap_aux[map_data[i * 3]];
map_data[i * 3 + 2] = out_lods[i + 1] - out_lods[i];
map_lods.push_back(i);
}
map_lods.push_back(ins_after_filter.size());
std::vector<Vector<size_t>> map_lod_info;
map_lod_info.push_back(map_lods);
map->set_lod(map_lod_info);
loss_weight->set_lod(map_lod_info);
std::vector<Vector<size_t>> out_lod_info;
out_lod_info.push_back(out_lods);
out->set_lod(out_lod_info);
memset(out_data, 0, out->numel() * sizeof(T));
for (size_t i = 0; i < loss_weight->numel(); i++) {
loss_weight_data[i] = 1;
}
for (size_t i = 0; i < ins_after_filter.size(); i++) {
size_t pos = out_lods[i];
for (size_t k = x1_lods[ins_after_filter[i]];
k < x1_lods[ins_after_filter[i] + 1]; k++) {
memcpy(out_data + pos * x1_embed_size, x1_data + k * x1_embed_size,
x1_embed_size * sizeof(T));
++pos;
}
}
} else {
Vector<size_t> map_lods;
map_data[0] = 0;
map_data[1] = 1;
map_data[2] = 1;
map_lods.push_back(0);
map_lods.push_back(1);
out_lods.push_back(1);
std::vector<Vector<size_t>> map_lod_info;
map_lod_info.push_back(map_lods);
map->set_lod(map_lod_info);
loss_weight->set_lod(map_lod_info);
std::vector<Vector<size_t>> out_lod_info;
out_lod_info.push_back(out_lods);
out->set_lod(out_lod_info);
memset(out_data, 0, out->numel() * sizeof(T));
loss_weight_data[0] = 0;
}
}
};
template <typename T>
class FilterByInstagGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* output_grad = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* x1_grad = context.Output<LoDTensor>(framework::GradVarName("Ins"));
auto* loss_weight = context.Input<LoDTensor>("LossWeight");
auto* mmap = context.Input<LoDTensor>("IndexMap");
auto* x1 = context.Input<LoDTensor>("Ins");
x1_grad->set_lod(context.Input<LoDTensor>("Ins")->lod());
x1_grad->Resize(x1->dims());
auto mmap_data = mmap->data<int64_t>();
// expected auto = T
auto* output_grad_data = output_grad->data<T>();
auto* loss_weight_data = loss_weight->data<float>();
// expected auto = T
auto* x1_grad_data = x1_grad->mutable_data<T>(context.GetPlace());
memset(x1_grad_data, 0, x1->dims()[0] * x1->dims()[1] * sizeof(T));
if (loss_weight->numel() != 1 || loss_weight_data[0] != 0) {
auto output_dims = output_grad->dims();
for (size_t i = 0; i < mmap->dims()[0]; i++) {
int src_ln = mmap_data[i * 3], dst_ln = mmap_data[i * 3 + 1];
int line_cnt = mmap_data[i * 3 + 2];
for (size_t l = 0; l < line_cnt; l++) {
for (size_t j = 0; j < output_dims[1]; j++) {
x1_grad_data[(dst_ln + l) * output_dims[1] + j] =
output_grad_data[(src_ln + l) * output_dims[1] + j];
}
}
}
}
}
};
} // namespace operators
} // namespace paddle
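To make the kernels above easier to follow, here is a minimal NumPy sketch of the same forward filter and backward scatter. This is illustrative only, not the Paddle implementation: the helper names filter_forward and filter_backward are hypothetical, and LoDs are offset-style ([0, n1, n1+n2, ...]) as inside the kernel.

import numpy as np

def filter_forward(x1, x1_lod, x2, x2_lod, filter_tag):
    # x1: [total_rows, embed]; x2: flat int64 tag array
    tags = set(int(t) for t in filter_tag)
    out_lods, kept = [0], []
    for i in range(len(x2_lod) - 1):
        ins_tags = set(int(t) for t in x2[x2_lod[i]:x2_lod[i + 1]])
        if tags & ins_tags:  # the i-th ins carries at least one wanted tag
            kept.append(i)
            out_lods.append(out_lods[-1] + (x1_lod[i + 1] - x1_lod[i]))
    if kept:
        out = np.concatenate([x1[x1_lod[i]:x1_lod[i + 1]] for i in kept])
        # IndexMap rows: (start row in Out, start row in Ins, row count)
        mmap = np.array([[out_lods[j], x1_lod[kept[j]],
                          out_lods[j + 1] - out_lods[j]]
                         for j in range(len(kept))], dtype=np.int64)
        loss_weight = np.ones((len(kept), 1), dtype=np.float32)
    else:  # nothing passed the filter: emit one all-zero row with weight 0
        out = np.zeros((1, x1.shape[1]), dtype=x1.dtype)
        out_lods = [0, 1]
        mmap = np.array([[0, 1, 1]], dtype=np.int64)
        loss_weight = np.zeros((1, 1), dtype=np.float32)
    return out, out_lods, mmap, loss_weight

def filter_backward(out_grad, mmap, x1_shape, loss_weight):
    # scatter Out gradients back to Ins rows through the index map
    x1_grad = np.zeros(x1_shape, dtype=out_grad.dtype)
    if not (loss_weight.size == 1 and loss_weight.flat[0] == 0):
        for src, dst, cnt in mmap:
            x1_grad[dst:dst + cnt] = out_grad[src:src + cnt]
    return x1_grad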
......@@ -211,6 +211,7 @@ __all__ = [
'deformable_conv',
'unfold',
'deformable_roi_pooling',
'filter_by_instag',
'var_conv_2d',
'shard_index',
'hard_swish',
......@@ -9754,6 +9755,76 @@ def stack(x, axis=0):
return out
@templatedoc(op_type="filter_by_instag")
def filter_by_instag(ins, ins_tag, filter_tag, is_lod):
"""
**Filter By Instag Layer**
This function filter a batch of ins by instag,
There are multiple ins, and every ins belongs to some tags.
We can specify some tags we want. So the ins which belongs to that tags
remains in the output, and others removed.
For example, one batch has 4 ins. Every ins has its tag list.
| Ins | Ins_Tag |
|:-----:|:------:|
| 0 | 0, 1 |
| 1 | 1, 3 |
| 2 | 0, 3 |
| 3 | 2, 6 |
And Lod is [1,1,1,1]
And the filter tags [1]
From the definition above, ins which has tag 1 can pass the filter
So Ins 0 and Ins 1 can pass and be seen in the output,
Ins 2 and 3 cannot pass because they do not has tag 1.
Actually, if is_lod is false, it is normal tensor that equals to
lod_tensor with all 1, similar to the example above.
Args:
ins (Variable): Input Variable (LoDTensor), usually it is 2D tensor
And first dimension can have lod info or not.
ins_tag (Variable): Input Variable (LoDTensor), usually it is 1D list
And split them by lod info
filter_tag (Variable): Input Variable (1D Tensor/List), usually it is
list that holds the tags.
is_lod (Bool): Boolean value to indicate ins is lod tensor or not.
Returns:
Variable: filtered ins (LoDTensor) and loss weight (Tensor)
Examples:
.. code-block:: python
import paddle.fluid.layers as layers
ins = layers.data(name='Ins', shape=[-1,32], lod_level=0, dtype='float64')
ins_tag = layers.data(name='Ins_tag', shape=[-1,16], lod_level=0, dtype='int64')
filter_tag = layers.data(name='Filter_tag', shape=[-1,16], dtype='int64')
out, loss_weight = layers.filter_by_instag(ins, ins_tag, filter_tag, True)
"""
helper = LayerHelper('filter_by_instag', **locals())
out = helper.create_variable_for_type_inference(dtype=ins.dtype)
loss_weight = helper.create_variable_for_type_inference(dtype=np.float64)
mmap = helper.create_variable_for_type_inference(dtype=ins_tag.dtype)
helper.append_op(
type='filter_by_instag',
inputs={'Ins': ins,
'Ins_tag': ins_tag,
'Filter_tag': filter_tag},
outputs={'Out': out,
'LossWeight': loss_weight,
'IndexMap': mmap},
attrs={'is_lod': is_lod})
return [out, loss_weight]
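As a fuller usage sketch of the layer above: this is an untested illustration under the fluid 1.x static-graph API; the feed shapes, dtypes, and tensor values here are illustrative assumptions, not part of this change.

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers

ins = layers.data(name='Ins', shape=[-1, 4], dtype='float64', lod_level=1)
ins_tag = layers.data(name='Ins_tag', shape=[-1, 1], dtype='int64', lod_level=1)
filter_tag = layers.data(name='Filter_tag', shape=[-1, 1], dtype='int64')
out, loss_weight = layers.filter_by_instag(ins, ins_tag, filter_tag, True)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# 4 ins of one row each; ins 1 and 3 carry tag 1 and should pass the filter
ins_t = fluid.create_lod_tensor(
    np.arange(16, dtype='float64').reshape(4, 4), [[1, 1, 1, 1]], place)
tag_t = fluid.create_lod_tensor(
    np.array([[2], [1], [2], [1]], dtype='int64'), [[1, 1, 1, 1]], place)
out_v, lw_v = exe.run(
    feed={'Ins': ins_t, 'Ins_tag': tag_t,
          'Filter_tag': np.array([[1]], dtype='int64')},
    fetch_list=[out, loss_weight], return_numpy=False)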
def unstack(x, axis=0, num=None):
"""
**UnStack Layer**
......
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This is the lib for gradient checker unittest."""
from __future__ import print_function
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This is unit test of Test filter_instag Op."""
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from op_test import OpTest
import random
from decorator_helper import prog_scope
"""This is Test Case 1"""
class TestFilterByInstagOp(OpTest):
def setUp(self):
self.op_type = 'filter_by_instag'
x1 = np.zeros((36, 4), dtype=np.float64)
for i in range(36):
for j in range(4):
x1[i, j] = i
x1_lod = [[1, 2, 3, 4, 5, 6, 7, 8]]
x2 = np.array([[1], [2], [1], [2], [1], [2], [1], [2]]).astype('int64')
x2_lod = [[1, 1, 1, 1, 1, 1, 1, 1]]
x3 = np.array([2]).astype('int64')
out = np.zeros((20, 4), dtype=np.float64)
out_lod = [[2, 4, 6, 8]]
start_num_lst = [1, 6, 15, 28]
ln = 0
for i in range(4):
start = start_num_lst[i]
length = out_lod[0][i]
for j in range(length):
cur = start + j
for k in range(4):
out[ln, k] = cur
ln += 1
mmap = np.array(
[[0, 1, 2], [2, 6, 4], [6, 15, 6], [12, 28, 8]]).astype('int64')
mmap_lod = [[1, 1, 1, 1]]
loss_weight = np.array([[1], [1], [1], [1]]).astype('double')
self.inputs = {
'Ins': (x1, x1_lod),
'Ins_tag': (x2, x2_lod),
'Filter_tag': x3,
}
self.outputs = {
'Out': (out, out_lod),
'LossWeight': (loss_weight, mmap_lod),
'IndexMap': (mmap, mmap_lod)
}
self.attrs = {'is_lod': True}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(
['Ins'], 'Out', no_grad_set=set(['Ins_tag', 'Filter_tag']))
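For reference, the expected IndexMap above follows directly from the LoDs: the lengths [1, 2, ..., 8] give Ins start offsets [0, 1, 3, 6, 10, 15, 21, 28], and tag 2 selects ins 1, 3, 5, 7, so the kept blocks start at rows 1, 6, 15, 28 with lengths 2, 4, 6, 8. A short sketch reproduces the rows (the helper name expected_index_map is hypothetical):

import numpy as np

def expected_index_map(x1_lens, x2_tags, filter_tag):
    # offsets[i] is the first x1 row of ins i (offset-style LoD)
    offsets = np.concatenate([[0], np.cumsum(x1_lens)])
    rows, out_pos = [], 0
    for i, tags in enumerate(x2_tags):
        if filter_tag & set(tags):  # ins i carries a wanted tag
            rows.append([out_pos, offsets[i], x1_lens[i]])
            out_pos += x1_lens[i]
    return np.array(rows, dtype=np.int64)

print(expected_index_map([1, 2, 3, 4, 5, 6, 7, 8],
                         [[1], [2], [1], [2], [1], [2], [1], [2]], {2}))
# -> [[0 1 2], [2 6 4], [6 15 6], [12 28 8]], matching mmap above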
"""This is Test Case 2"""
class TestFilterByInstagOp2(OpTest):
def setUp(self):
self.op_type = 'filter_by_instag'
batch_size = 4
x1_embed_size = 4
fc_cnt = 2
x1 = np.array([[10, 13, 12, 1], [1, 1, 1, 1], [1, 1, 1, 1],
[1, 1, 1, 1]]).astype('double')
x1_lod = [[1, 1, 1, 1]]
x2 = np.array([[2], [1], [2], [1]]).astype('int64')
x2_lod = [[1, 1, 1, 1]]
x3 = np.array([1]).astype('int64')
out = np.array([[1, 1, 1, 1], [1, 1, 1, 1]]).astype('double')
out_lod = [[1, 1]]
mmap = np.array([[0, 1, 1], [1, 3, 1]]).astype('int64')
mmap_lod = [[1, 1]]
loss_weight = np.array([[1], [1]]).astype('double')
self.inputs = {
'Ins': (x1, x1_lod),
'Ins_tag': (x2, x2_lod),
'Filter_tag': x3,
}
self.outputs = {
'Out': (out, out_lod),
'LossWeight': (loss_weight, mmap_lod),
'IndexMap': (mmap, mmap_lod)
}
self.attrs = {'is_lod': True, }
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(
['Ins'], 'Out', no_grad_set=set(['Ins_tag', 'Filter_tag']))
"""This is Test Case 3"""
class TestFilterByInstagOp3(OpTest):
def setUp(self):
self.op_type = 'filter_by_instag'
batch_size = 4
x1_embed_size = 4
fc_cnt = 2
x1 = np.array([[10, 13, 12, 1], [1, 1, 1, 1], [1, 1, 1, 1],
[1, 1, 1, 1]]).astype('double')
x1_lod = [[1, 1, 1, 1]]
x2 = np.array([[2], [1], [2], [1]]).astype('int64')
x2_lod = [[1, 1, 1, 1]]
x3 = np.array([3]).astype('int64')
out = np.array([[0, 0, 0, 0]]).astype('double')
out_lod = [[1]]
mmap = np.array([[0, 1, 1]]).astype('int64')
mmap_lod = [[1]]
loss_weight = np.array([[0]]).astype('double')
self.inputs = {
'Ins': (x1, x1_lod),
'Ins_tag': (x2, x2_lod),
'Filter_tag': x3,
}
self.outputs = {
'Out': (out, out_lod),
'LossWeight': (loss_weight, mmap_lod),
'IndexMap': (mmap, mmap_lod)
}
self.attrs = {'is_lod': True, }
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(
['Ins'], 'Out', no_grad_set=set(['Ins_tag', 'Filter_tag']))
"""This is Test Case 4"""
class TestFilterByInstagOp4(OpTest):
def setUp(self):
self.op_type = 'filter_by_instag'
batch_size = 4
x1_embed_size = 4
fc_cnt = 2
x1 = np.array([[10, 13, 12, 1], [1, 1, 1, 1], [1, 1, 1, 1],
[1, 1, 1, 1]]).astype('double')
x2 = np.array([[2], [1], [2], [1]]).astype('int64')
x2_lod = [[1, 1, 1, 1]]
x3 = np.array([3]).astype('int64')
out = np.array([[0, 0, 0, 0]]).astype('double')
out_lod = [[1]]
mmap = np.array([[0, 1, 1]]).astype('int64')
mmap_lod = [[1]]
loss_weight = np.array([[0]]).astype('double')
self.inputs = {
'Ins': x1,
'Ins_tag': (x2, x2_lod),
'Filter_tag': x3,
}
self.outputs = {
'Out': (out, out_lod),
'LossWeight': (loss_weight, mmap_lod),
'IndexMap': (mmap, mmap_lod)
}
self.attrs = {'is_lod': False, }
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(
['Ins'], 'Out', no_grad_set=set(['Ins_tag', 'Filter_tag']))
if __name__ == '__main__':
unittest.main()
......@@ -1896,14 +1896,6 @@ class TestBook(LayerTest):
self.assertTrue(z.lod_level == 1)
return z
def test_lod_append(self):
with self.static_graph():
x = layers.data(
name='x', shape=[6, 10], dtype='float32', lod_level=1)
y = layers.lod_append(x, [1, 1, 1, 1, 1, 1])
self.assertTrue(y.lod_level == 1)
return y
def test_affine_grid(self):
with self.static_graph():
data = layers.data(name='data', shape=[2, 3, 3], dtype="float32")
......@@ -1999,6 +1991,26 @@ class TestBook(LayerTest):
input=seqs, offset=offset, length=length)
return (out)
def test_filter_by_instag(self):
# TODO(minqiyang): dygraph does not support lod yet
with self.static_graph():
x1 = layers.data(
name='Ins', shape=[32, 1], dtype='float32', lod_level=0)
x2 = layers.data(
name='Ins_tag',
shape=[32, 1],
dtype='int64',
lod_level=0,
stop_gradient=True)
x3 = layers.create_global_var(
shape=[1, 1],
value=20,
dtype='int64',
persistable=True,
force_cpu=True,
name='Filter_tag')
out1, out2 = layers.filter_by_instag(x1, x2, x3, is_lod=True)
def test_roi_pool(self):
# TODO(minqiyang): dygraph does not support lod yet
with self.static_graph():
......