提交 9b558a80 编写于 作者: P peizhilin

Merge branch 'windows/build' into windows/online

test=develop
......@@ -72,49 +72,21 @@ IF(NOT ${CBLAS_FOUND})
ENDIF()
ENDIF()
IF(WIN32)
ExternalProject_Add(
extern_openblas
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY https://github.com/xianyi/OpenBLAS.git
GIT_TAG ${OPENBLAS_COMMIT}
PREFIX ${CBLAS_SOURCES_DIR}
INSTALL_DIR ${CBLAS_INSTALL_DIR}
BUILD_IN_SOURCE 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DNO_SHARED=ON
-DNO_STATIC=OFF
-DBUILD_WITHOUT_LAPACK=ON
-DUSE_THREAD=OFF
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${CBLAS_INSTALL_DIR}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)
ELSE(WIN32)
SET(COMMON_ARGS CC=${OPENBLAS_CC} NO_SHARED=1 NO_LAPACK=1 libs)
ExternalProject_Add(
extern_openblas
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY https://github.com/xianyi/OpenBLAS.git
GIT_TAG ${OPENBLAS_COMMIT}
PREFIX ${CBLAS_SOURCES_DIR}
INSTALL_DIR ${CBLAS_INSTALL_DIR}
BUILD_IN_SOURCE 1
BUILD_COMMAND ${CMAKE_MAKE_PROGRAM} ${COMMON_ARGS} ${OPTIONAL_ARGS}
INSTALL_COMMAND ${CMAKE_MAKE_PROGRAM} install NO_SHARED=1 NO_LAPACK=1 PREFIX=<INSTALL_DIR>
&& rm -r ${CBLAS_INSTALL_DIR}/lib/cmake ${CBLAS_INSTALL_DIR}/lib/pkgconfig
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
)
ENDIF(WIN32)
SET(COMMON_ARGS CC=${OPENBLAS_CC} NO_SHARED=1 NO_LAPACK=1 libs)
ExternalProject_Add(
extern_openblas
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY https://github.com/xianyi/OpenBLAS.git
GIT_TAG ${OPENBLAS_COMMIT}
PREFIX ${CBLAS_SOURCES_DIR}
INSTALL_DIR ${CBLAS_INSTALL_DIR}
BUILD_IN_SOURCE 1
BUILD_COMMAND ${CMAKE_MAKE_PROGRAM} ${COMMON_ARGS} ${OPTIONAL_ARGS}
INSTALL_COMMAND ${CMAKE_MAKE_PROGRAM} install NO_SHARED=1 NO_LAPACK=1 PREFIX=<INSTALL_DIR>
&& rm -r ${CBLAS_INSTALL_DIR}/lib/cmake ${CBLAS_INSTALL_DIR}/lib/pkgconfig
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
)
ENDIF (WITH_PREBUILD_OPENBLAS)
SET(CBLAS_PROVIDER openblas)
......
......@@ -179,6 +179,7 @@ paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], vara
paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.similarity_focus ArgSpec(args=['input', 'axis', 'indexes', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
......@@ -201,6 +202,7 @@ paddle.fluid.layers.create_tensor ArgSpec(args=['dtype', 'name', 'persistable'],
paddle.fluid.layers.create_parameter ArgSpec(args=['shape', 'dtype', 'name', 'attr', 'is_bias', 'default_initializer'], varargs=None, keywords=None, defaults=(None, None, False, None))
paddle.fluid.layers.create_global_var ArgSpec(args=['shape', 'value', 'dtype', 'persistable', 'force_cpu', 'name'], varargs=None, keywords=None, defaults=(False, False, None))
paddle.fluid.layers.cast ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.tensor_array_to_tensor ArgSpec(args=['input', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.concat ArgSpec(args=['input', 'axis', 'name'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.sums ArgSpec(args=['input', 'out'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.assign ArgSpec(args=['input', 'output'], varargs=None, keywords=None, defaults=(None,))
......
......@@ -321,6 +321,7 @@ op_library(load_op DEPS lod_tensor)
op_library(save_combine_op DEPS lod_tensor)
op_library(load_combine_op DEPS lod_tensor)
op_library(concat_op DEPS concat_and_split)
op_library(tensor_array_to_tensor_op DEPS concat_op)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/similarity_focus_op.h"
namespace paddle {
namespace operators {
class SimilarityFocusOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), a 4-D tensor with shape,"
" [BatchSize, X, Y, Z]");
AddOutput("Out",
"(Tensor, default Tensor<float>), the similarity focus mask"
" with the same shape of input X.");
AddAttr<int>("axis",
"(int32), indicating the dimension to be select. It can"
" only be 1, 2, or 3.");
AddAttr<std::vector<int>>("indexes",
"(std::vector<int32>), indicating the indexes"
" of the selected dimension.");
AddComment(R"DOC(
SimilarityFocus Operator.
Generate a similarity focus mask with the same shape of input using the following method:
1. Extract the 3-D tensor(here the first dimension is BatchSize) corresponding
to the axis according to the indexes. For example, if axis=1 and indexes=[a],
it will get the matrix T=X[:, a, :, :]. In this case, if the shape of input X
is (BatchSize, A, B, C), the shape of tensor T is (BatchSize, B, C).
2. For each index, find the largest numbers in the tensor T, so that the same
row and same column has at most one number(what it means is that if the
largest number has been found in the i-th row and the j-th column, then
the numbers in the i-th row or j-th column will be skipped. And then the
next largest number will be selected from the remaining numbers. Obviously
there will be min(B, C) numbers), and mark the corresponding position of the
3-D similarity focus mask as 1, otherwise as 0. Do elementwise-or for
each index.
3. Broadcast the 3-D similarity focus mask to the same shape of input X.
Refer to `Similarity Focus Layer <http://www.aclweb.org/anthology/N16-1108>`_
)DOC");
}
};
class SimilarityFocusOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 4, "Input(X)'s rank should be 4.");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
platform::CPUPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(similarity_focus, ops::SimilarityFocusOp,
ops::SimilarityFocusOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(similarity_focus, ops::SimilarityFocusKernel<float>,
ops::SimilarityFocusKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <cstring>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class SimilarityFocusKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
Tensor* out = context.Output<Tensor>("Out");
const Tensor* x = context.Input<Tensor>("X");
T* out_data = out->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>();
int axis = context.Attr<int>("axis");
std::vector<int> indexes = context.Attr<std::vector<int>>("indexes");
int64_t batch_size = x->dims()[0];
int64_t dim[4];
for (int i = 1; i <= 3; ++i) {
dim[i] = x->dims()[i];
}
if (indexes.size() < 1) {
PADDLE_THROW("Indexes' size can not be 0.");
}
for (auto index : indexes) {
if (dim[axis] < index) {
PADDLE_THROW("Index exceeds tensor shape limit.");
}
}
int64_t array_size = 1;
for (int i = 1; i <= 3; ++i) {
if (i != axis) {
array_size *= dim[i];
}
}
std::vector<std::pair<T, int64_t>> array(array_size);
bool (*cmp)(std::pair<T, int64_t>, std::pair<T, int64_t>) = [](
std::pair<T, int64_t> x, std::pair<T, int64_t> y) {
return x.first > y.first;
};
int64_t (*compute_index)(int64_t*, int, int, int, int) = [](
int64_t* dim, int d1, int d2, int d3, int d4) {
return d1 * dim[1] * dim[2] * dim[3] + d2 * dim[2] * dim[3] +
d3 * dim[3] + d4;
};
memset(out_data, 0, sizeof(T) * batch_size * dim[1] * dim[2] * dim[3]);
for (int i = 0; i < batch_size; ++i) {
for (auto index : indexes) {
if (axis == 1) {
for (int j = 0; j < dim[2]; ++j) {
for (int k = 0; k < dim[3]; ++k) {
array[j * dim[3] + k] = std::make_pair(
x_data[compute_index(dim, i, index, j, k)], j * dim[3] + k);
}
}
std::sort(array.begin(), array.end(), cmp);
int tag_num = 0;
std::vector<bool> tag2(dim[2]), tag3(dim[3]);
for (auto x : array) {
int idx2 = x.second / dim[3];
int idx3 = x.second % dim[3];
if (tag2[idx2] || tag3[idx3]) {
continue;
}
tag_num++;
tag2[idx2] = true;
tag3[idx3] = true;
for (int j = 0; j < dim[1]; ++j) {
out_data[compute_index(dim, i, j, idx2, idx3)] = 1;
}
if (tag_num == std::min(dim[2], dim[3])) {
break;
}
}
} else if (axis == 2) {
for (int j = 0; j < dim[1]; ++j) {
for (int k = 0; k < dim[3]; ++k) {
array[j * dim[3] + k] = std::make_pair(
x_data[compute_index(dim, i, j, index, k)], j * dim[3] + k);
}
}
std::sort(array.begin(), array.end(), cmp);
int tag_num = 0;
std::vector<bool> tag1(dim[1]), tag3(dim[3]);
for (auto x : array) {
int idx1 = x.second / dim[3];
int idx3 = x.second % dim[3];
if (tag1[idx1] || tag3[idx3]) {
continue;
}
tag_num++;
tag1[idx1] = true;
tag3[idx3] = true;
for (int j = 0; j < dim[2]; ++j) {
out_data[compute_index(dim, i, idx1, j, idx3)] = 1;
}
if (tag_num == std::min(dim[1], dim[3])) {
break;
}
}
} else if (axis == 3) {
for (int j = 0; j < dim[1]; ++j) {
for (int k = 0; k < dim[2]; ++k) {
array[j * dim[2] + k] = std::make_pair(
x_data[compute_index(dim, i, j, k, index)], j * dim[2] + k);
}
}
std::sort(array.begin(), array.end(), cmp);
int tag_num = 0;
std::vector<bool> tag1(dim[1]), tag2(dim[2]);
for (auto x : array) {
int idx1 = x.second / dim[2];
int idx2 = x.second % dim[2];
if (tag1[idx1] || tag2[idx2]) {
continue;
}
tag_num++;
tag1[idx1] = true;
tag2[idx2] = true;
for (int j = 0; j < dim[3]; ++j) {
out_data[compute_index(dim, i, idx1, idx2, j)] = 1;
}
if (tag_num == std::min(dim[1], dim[2])) {
break;
}
}
} else {
PADDLE_THROW("Axis must be 1 or 2 or 3");
}
}
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace operators {
using framework::Tensor;
void LodTensorArray2LodTensorVector(const framework::Scope &scope,
const std::string &base_name,
const std::string &lod_tensor_array_name,
std::vector<std::string> *res_names) {
auto &inx =
scope.FindVar(lod_tensor_array_name)->Get<framework::LoDTensorArray>();
for (size_t i = 0; i < inx.size(); i++) {
std::string var_name = base_name + std::to_string(i);
framework::Variable *g_feed_value =
const_cast<framework::Scope &>(scope).Var(var_name);
auto &feed_input =
*(g_feed_value->GetMutable<paddle::framework::LoDTensor>());
feed_input.ShareDataWith(inx[i]);
res_names->push_back(var_name);
}
}
void LodTensorVectorResizeFromLodTensorArray(
const framework::Scope &scope, const std::string &base_name,
const std::string &lod_tensor_array_name,
std::vector<std::string> *res_names) {
auto &inx =
scope.FindVar(lod_tensor_array_name)->Get<framework::LoDTensorArray>();
for (size_t i = 0; i < inx.size(); i++) {
std::string var_name = base_name + std::to_string(i);
framework::Variable *g_feed_value =
const_cast<framework::Scope &>(scope).Var(var_name);
auto &feed_input =
*(g_feed_value->GetMutable<paddle::framework::LoDTensor>());
auto dims = inx[i].dims();
feed_input.Resize(dims);
res_names->push_back(var_name);
}
}
void LodTensorArrayCreateFromLodTensorArray(
const framework::Scope &scope,
const std::string &input_lod_tensor_array_name,
const std::string &output_lod_tensor_array_name) {
auto &inx = scope.FindVar(input_lod_tensor_array_name)
->Get<framework::LoDTensorArray>();
auto &grad_inx = *scope.FindVar(output_lod_tensor_array_name)
->GetMutable<framework::LoDTensorArray>();
for (size_t i = 0; i < inx.size(); i++) {
std::string var_name = output_lod_tensor_array_name + std::to_string(i);
framework::Variable *g_feed_value =
const_cast<framework::Scope &>(scope).Var(var_name);
auto &feed_input =
*(g_feed_value->GetMutable<paddle::framework::LoDTensor>());
grad_inx.push_back(feed_input);
}
}
class LoDTensorArray2TensorOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto axis = Attr<int>("axis");
framework::AttributeMap attrs;
attrs["axis"] = axis;
auto &inx = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
auto &out =
*scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
auto &out_inx =
*scope.FindVar(Output("OutIndex"))->GetMutable<framework::LoDTensor>();
const size_t n = inx.size();
PADDLE_ENFORCE_GT(n, 0, "Input tensorarray size should > 0.");
std::string base_name = Inputs("X")[0];
std::vector<std::string> names;
// get the input tensorarray items' dim in out_inx
auto out_inx_dim = out_inx.dims();
out_inx_dim[0] = inx.size();
out_inx.Resize(out_inx_dim);
std::string var_name = "out_index";
framework::Variable *tmp_index_var =
const_cast<framework::Scope &>(scope).Var(var_name);
auto &tmp_index_tensor =
*(tmp_index_var->GetMutable<paddle::framework::LoDTensor>());
tmp_index_tensor.Resize(out_inx_dim);
int *tmp_index_data =
tmp_index_tensor.mutable_data<int>(platform::CPUPlace());
auto out_dims = inx[0].dims();
size_t out_dim_sum = 0;
for (size_t index = 0; index < inx.size(); index++) {
auto inx_dims = inx[index].dims();
out_dim_sum += inx_dims[axis];
tmp_index_data[index] = inx_dims[axis];
}
out_inx.ShareDataWith(tmp_index_tensor);
// get input array items' dims
out_dims[axis] = out_dim_sum;
out.Resize(out_dims);
LodTensorArray2LodTensorVector(scope, base_name, Input("X"), &names);
// Invoke Reshape Op
auto concat_op = framework::OpRegistry::CreateOp(
"concat", {{"X", names}}, {{"Out", {Output("Out")}}}, attrs);
concat_op->Run(scope, place);
}
};
class LoDTensorArray2TensorOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input LoDTensorArray of tensor_array_to_tensor operator.");
AddOutput("Out", "Output tensor of tensor_array_to_tensor operator.");
AddOutput("OutIndex",
"Output input LoDTensorArray items' dims of "
"tensor_array_to_tensor operator.");
AddAttr<int>("axis",
"The axis along which the input tensors will be concatenated.")
.SetDefault(0);
AddComment(R"DOC(
tensor_array_to_tensor Operator.
Concatenate the input LoDTensorArray along dimension axis to the output Tensor.
Examples:
Input = {[1,2], [3,4], [5,6]}
axis = 0
Output = [[1,2],
[3,4],
[5,6]]
OutputIndex = [1,1,1]
)DOC");
}
};
class LoDTensorArray2TensorOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};
class LoDTensorArray2TensorGradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {}
};
class LoDTensorArray2TensorGradInferVarType
: public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &out_var : op_desc.Output(framework::GradVarName("X"))) {
block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
}
}
};
class LoDTensorArray2TensorGradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto axis = Attr<int>("axis");
framework::AttributeMap attrs;
attrs["axis"] = axis;
auto &inx = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
const size_t n = inx.size();
PADDLE_ENFORCE_GT(n, 0, "Input tensorarray size should > 0.");
std::string base_name = Inputs("X")[0];
std::vector<std::string> names;
LodTensorArray2LodTensorVector(scope, base_name, Input("X"), &names);
// grad
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
std::vector<std::string> grad_names;
LodTensorVectorResizeFromLodTensorArray(scope, "grad_name", Input("X"),
&grad_names);
auto concat_grad_op = framework::OpRegistry::CreateOp(
"concat_grad", {{"X", names}, {"Out@GRAD", {dout_name}}},
{{"X@GRAD", grad_names}}, attrs);
concat_grad_op->Run(scope, place);
LodTensorArrayCreateFromLodTensorArray(scope, Input("X"), dx_name);
auto &grad_inx =
*scope.FindVar(dx_name)->GetMutable<framework::LoDTensorArray>();
for (size_t i = 0; i < grad_names.size(); i++) {
std::string var_name = grad_names[i];
auto &feed_input = scope.FindVar(var_name)->Get<framework::LoDTensor>();
grad_inx[i].ShareDataWith(feed_input);
}
}
};
} // namespace operators
} // namespace paddle
USE_OP(concat);
namespace ops = paddle::operators;
REGISTER_OPERATOR(tensor_array_to_tensor, ops::LoDTensorArray2TensorOp,
ops::LoDTensorArray2TensorOpMaker,
ops::LoDTensorArray2TensorOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(tensor_array_to_tensor_grad, ops::LoDTensorArray2TensorGradOp,
ops::LoDTensorArray2TensorGradInferShape,
ops::LoDTensorArray2TensorGradInferVarType);
......@@ -44,7 +44,7 @@ limitations under the License. */
#include <boost/variant.hpp>
// some platform-independent defintion
#if defined(WIN32) || defined(_WIN32) || defined(__WIN32__)
#if defined(_WIN32)
#define __UNUSED__()
#define __builtin_expect(EXP, C) (EXP)
#else
......
......@@ -161,6 +161,7 @@ __all__ = [
'affine_grid',
'sequence_reverse',
'affine_channel',
'similarity_focus',
'hash',
'grid_sampler',
'log_loss',
......@@ -7937,6 +7938,118 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
return out
def similarity_focus(input, axis, indexes, name=None):
"""
SimilarityFocus Operator
Generate a similarity focus mask with the same shape of input using the following method:
1. Extract the 3-D tensor(here the first dimension is BatchSize) corresponding
to the axis according to the indexes. For example, if axis=1 and indexes=[a],
it will get the matrix T=X[:, a, :, :]. In this case, if the shape of input X
is (BatchSize, A, B, C), the shape of tensor T is (BatchSize, B, C).
2. For each index, find the largest numbers in the tensor T, so that the same
row and same column has at most one number(what it means is that if the
largest number has been found in the i-th row and the j-th column, then
the numbers in the i-th row or j-th column will be skipped. And then the
next largest number will be selected from the remaining numbers. Obviously
there will be min(B, C) numbers), and mark the corresponding position of the
3-D similarity focus mask as 1, otherwise as 0. Do elementwise-or for
each index.
3. Broadcast the 3-D similarity focus mask to the same shape of input X.
Refer to `Similarity Focus Layer <http://www.aclweb.org/anthology/N16-1108>`_
.. code-block:: text
* Example :
Given a 4-D tensor x with the shape (BatchSize, C, A, B), where C is
the number of channels and the shape of feature map is (A, B):
x.shape = (2, 3, 2, 2)
x.data = [[[[0.8, 0.1],
[0.4, 0.5]],
[[0.9, 0.7],
[0.9, 0.9]],
[[0.8, 0.9],
[0.1, 0.2]]],
[[[0.2, 0.5],
[0.3, 0.4]],
[[0.9, 0.7],
[0.8, 0.4]],
[[0.0, 0.2],
[0.4, 0.7]]]]
Given axis: 1 (the axis of the channel)
Given indexes: [0]
then we get a 4-D tensor out with the same shape of input x:
out.shape = (2, 3, 2, 2)
out.data = [[[[1.0, 0.0],
[0.0, 1.0]],
[[1.0, 0.0],
[0.0, 1.0]],
[[1.0, 0.0],
[0.0, 1.0]]],
[[[0.0, 1.0],
[1.0, 0.0]],
[[0.0, 1.0],
[1.0, 0.0]],
[[0.0, 1.0],
[1.0, 0.0]]]]
Args:
input(Variable): The input tensor variable(default float). It should
be a 4-D tensor with shape [BatchSize, A, B, C].
axis(int): Indicating the dimension to be selected. It can only be
1, 2 or 3.
indexes(list): Indicating the indexes of the selected dimension.
Returns:
Variable: A tensor variable with the same shape and same type
as the input.
Examples:
.. code-block:: python
data = fluid.layers.data(
name='data', shape=[2, 3, 2, 2], dtype='float32')
x = fluid.layers.layer_norm(input=data, axis=1, indexes=[0])
"""
helper = LayerHelper('similarity_focus', **locals())
# check attrs
if isinstance(axis, int) is False:
raise TypeError("axis must be int type.")
if isinstance(indexes, list) is False:
raise TypeError("indexes must be list type.")
if axis != 1 and axis != 2 and axis != 3:
raise ValueError("axis must be 1, 2 or 3.")
if len(indexes) == 0:
raise ValueError("indexes can not be empty.")
if name is None:
out = helper.create_variable_for_type_inference(dtype=input.dtype)
else:
out = helper.create_variable(
name=name, dtype=input.dtype, persistable=False)
helper.append_op(
type='similarity_focus',
inputs={'X': input},
outputs={'Out': out},
attrs={"axis": axis,
"indexes": indexes})
return out
def hash(input, hash_size, num_hash=1, name=None):
"""
Hash the input to an integer whose value is less than the given hash size.
......
......@@ -24,10 +24,10 @@ from .layer_function_generator import templatedoc
import numpy
__all__ = [
'create_tensor', 'create_parameter', 'create_global_var', 'cast', 'concat',
'sums', 'assign', 'fill_constant_batch_size_like', 'fill_constant',
'argmin', 'argmax', 'argsort', 'ones', 'zeros', 'reverse', 'has_inf',
'has_nan', 'isfinite'
'create_tensor', 'create_parameter', 'create_global_var', 'cast',
'tensor_array_to_tensor', 'concat', 'sums', 'assign',
'fill_constant_batch_size_like', 'fill_constant', 'argmin', 'argmax',
'argsort', 'ones', 'zeros', 'reverse', 'has_inf', 'has_nan', 'isfinite'
]
......@@ -193,6 +193,60 @@ def concat(input, axis=0, name=None):
return out
def tensor_array_to_tensor(input, axis=1, name=None):
"""
This function concatenates the input LodTensorArray along the axis mentioned
and returns that as the output.
A simple example as below:
.. code-block:: text
Given:
input.data = {[[0.6, 0.1, 0.3],
[0.5, 0.3, 0.2]],
[[1.3],
[1.8]],
[[2.3, 2.1],
[2.5, 2.4]]}
axis = 1
Then:
output.data = [[0.6, 0.1, 0.3, 1.3, 2.3, 2.1],
[0.5, 0.3, 0.2, 1.8, 2.5, 2.4]]
output_index.data = [3, 1, 2]
Args:
input(list): Input LodTensorArray
axis(int): Integer axis along which the tensors will be concatenated
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: Output variable of the concatenation
Variable: The input LodTensorArray items' dims along the axis
Examples:
.. code-block:: python
output, output_index = fluid.layers.tensor_array_to_tensor(input=tensor_array)
"""
helper = LayerHelper('tensor_array_concat', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
out_index = helper.create_variable_for_type_inference(dtype="int32")
helper.append_op(
type='tensor_array_concat',
inputs={'X': input},
outputs={'Out': [out],
'OutIndex': [out_index]},
attrs={'axis': axis})
return out, out_index
def sums(input, out=None):
"""
This function performs the sum operation on the input and returns the
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
class TestSimilarityFocusOp(OpTest):
def setUp(self):
self.op_type = "similarity_focus"
batch_size = 2
x_dim, y_dim, z_dim = 3, 2, 2
self.inputs = {
'X': np.array([[[[0.8, 0.1], [0.4, 0.5]], [[0.9, 0.7], [0.9, 0.9]],
[[0.8, 0.9], [0.1, 0.2]]],
[[[0.2, 0.5], [0.3, 0.4]], [[0.9, 0.7], [0.8, 0.4]],
[[0.0, 0.2], [0.4, 0.7]]]]),
}
self.attrs = {
'axis': 1,
'indexes': [0],
}
output = None
for batch in range(batch_size):
res = np.zeros((1, y_dim, z_dim)).astype("float32").reshape(-1)
for index in self.attrs['indexes']:
channel = self.inputs['X'][batch, index, :, :].reshape(-1).copy(
)
tag1 = [0 for i in range(y_dim)]
tag2 = [0 for i in range(z_dim)]
cnt = 0
for i in range(channel.size):
index = channel.argmax()
idx1 = index // z_dim
idx2 = index % z_dim
if tag1[idx1] + tag2[idx2] == 0:
tag1[idx1] = 1
tag2[idx2] = 1
res[index] = 1
cnt += 1
if cnt == min(y_dim, z_dim):
break
channel[index] = -1
res = res.reshape(1, y_dim, z_dim).repeat([x_dim], axis=0)
res = res.reshape(1, x_dim, y_dim, z_dim)
if output is not None:
output = np.concatenate((output, res), axis=0)
else:
output = res
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
class TestSimilarityFocusOp_axis1(OpTest):
def setUp(self):
self.op_type = "similarity_focus"
batch_size = 3
x_dim, y_dim, z_dim = 4, 5, 6
self.inputs = {
'X': np.random.random(
(batch_size, x_dim, y_dim, z_dim)).astype("float32"),
}
self.attrs = {
'axis': 1,
'indexes': [0, 3],
}
output = None
for batch in range(batch_size):
res = np.zeros((1, y_dim, z_dim)).astype("float32").reshape(-1)
for index in self.attrs['indexes']:
channel = self.inputs['X'][batch, index, :, :].reshape(-1).copy(
)
tag1 = [0 for i in range(y_dim)]
tag2 = [0 for i in range(z_dim)]
cnt = 0
for i in range(channel.size):
index = channel.argmax()
idx1 = index // z_dim
idx2 = index % z_dim
if tag1[idx1] + tag2[idx2] == 0:
tag1[idx1] = 1
tag2[idx2] = 1
res[index] = 1
cnt += 1
if cnt == min(y_dim, z_dim):
break
channel[index] = -1
res = res.reshape(1, y_dim, z_dim)
res = res.repeat([x_dim], axis=0)
res = res.reshape(1, x_dim, y_dim, z_dim)
if output is not None:
output = np.concatenate((output, res), axis=0)
else:
output = res
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
class TestSimilarityFocusOp_axis2(OpTest):
def setUp(self):
self.op_type = "similarity_focus"
batch_size = 6
x_dim, y_dim, z_dim = 7, 8, 9
self.inputs = {
'X': np.random.random(
(batch_size, x_dim, y_dim, z_dim)).astype("float32"),
}
self.attrs = {
'axis': 2,
'indexes': [0, 3, 5],
}
output = None
for batch in range(batch_size):
res = np.zeros((x_dim, 1, z_dim)).astype("float32").reshape(-1)
for index in self.attrs['indexes']:
channel = self.inputs['X'][batch, :, index, :].reshape(-1).copy(
)
tag1 = [0 for i in range(x_dim)]
tag2 = [0 for i in range(z_dim)]
cnt = 0
for i in range(channel.size):
index = channel.argmax()
idx1 = index // z_dim
idx2 = index % z_dim
if tag1[idx1] + tag2[idx2] == 0:
tag1[idx1] = 1
tag2[idx2] = 1
res[index] = 1
cnt += 1
if cnt == min(x_dim, z_dim):
break
channel[index] = -1
res = res.reshape(x_dim, 1, z_dim)
res = res.repeat([y_dim], axis=1)
res = res.reshape(1, x_dim, y_dim, z_dim)
if output is not None:
output = np.concatenate((output, res), axis=0)
else:
output = res
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
class TestSimilarityFocusOp_axis3(OpTest):
def setUp(self):
self.op_type = "similarity_focus"
batch_size = 64
x_dim, y_dim, z_dim = 48, 48, 13
self.inputs = {
'X': np.random.random(
(batch_size, x_dim, y_dim, z_dim)).astype("float32"),
}
self.attrs = {
'axis': 3,
'indexes': [0, 2, 7, 9],
}
output = None
for batch in range(batch_size):
res = np.zeros((x_dim, y_dim, 1)).astype("float32").reshape(-1)
for index in self.attrs['indexes']:
channel = self.inputs['X'][batch, :, :, index].reshape(-1).copy(
)
tag1 = [0 for i in range(x_dim)]
tag2 = [0 for i in range(y_dim)]
cnt = 0
for i in range(channel.size):
index = channel.argmax()
idx1 = index // y_dim
idx2 = index % y_dim
if tag1[idx1] + tag2[idx2] == 0:
tag1[idx1] = 1
tag2[idx2] = 1
res[index] = 1
cnt += 1
if cnt == min(x_dim, y_dim):
break
channel[index] = -1
res = res.reshape(x_dim, y_dim, 1)
res = res.repeat([z_dim], axis=2)
res = res.reshape(1, x_dim, y_dim, z_dim)
if output is not None:
output = np.concatenate((output, res), axis=0)
else:
output = res
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.fluid.executor import Executor
class TestLoDTensorArrayConcat(unittest.TestCase):
def setUp(self):
self.op_type = "tensor_array_to_tensor"
self.attrs = {"axis": 0}
self.outputs = ["Out"]
def test_get_set(self):
scope = core.Scope()
program = fluid.Program()
block = program.global_block()
input_arr = block.create_var(
name="tmp_lod_tensor_array",
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY)
input_arr.persistable = True
input_arr_var = scope.var('tmp_lod_tensor_array')
input_tensor_array = input_arr_var.get_lod_tensor_array()
self.assertEqual(0, len(input_tensor_array))
cpu = core.CPUPlace()
for i in range(10):
t = core.LoDTensor()
if i == 0:
t.set(numpy.array([[i], [i]], dtype='float32'), cpu)
else:
t.set(numpy.array([[i]], dtype='float32'), cpu)
input_tensor_array.append(t)
self.assertEqual(10, len(input_tensor_array))
random_grad = numpy.random.random_sample([11]).astype(numpy.float32)
y_out = block.create_var(name="Out")
y_out.persistable = True
y_out_index = block.create_var(name="OutIndex")
y_out_index.persistable = True
y_grad_arr = block.create_var(
name='Out@GRAD', dtype='float32', shape=[11])
y_grad_arr.persistable = True
y_grad = scope.var('Out@GRAD')
y_grad_tensor = y_grad.get_tensor()
y_grad_tensor.set(random_grad, cpu)
op = block.append_op(
type=self.op_type,
inputs={"X": input_arr},
outputs={"Out": y_out,
"OutIndex": y_out_index},
attrs=self.attrs)
out_grad = block.create_var(
name="tmp_lod_tensor_array@GRAD",
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY)
out_grad.persistable = True
grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(op.desc,
set(), [])
grad_op_desc = grad_op_desc_list[0]
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(grad_op_desc)
for var_name in grad_op_desc.output_arg_names():
block.desc.var(var_name.encode("ascii"))
grad_op_desc.infer_var_type(block.desc)
grad_op_desc.infer_shape(block.desc)
for arg in grad_op_desc.output_arg_names():
grad_var = block.desc.find_var(arg.encode("ascii"))
grad_var.set_dtype(core.VarDesc.VarType.FP32)
fetch_list = []
fetch_list.append(block.var('Out'))
fetch_list.append(block.var('OutIndex'))
exe = fluid.Executor(fluid.CPUPlace())
out = exe.run(program, fetch_list=fetch_list, scope=scope)
#print ("index: ", numpy.array(out[1]))
# test forward
tensor_res = numpy.array(out[0])
tensor_res_out_idx = numpy.array(out[1])
tensor_gt = numpy.array(
[0] + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='float32')
self.assertEqual(len(tensor_res), len(tensor_gt))
self.assertEqual(len(tensor_res_out_idx), 10)
for i in range(len(tensor_res)):
self.assertEqual(tensor_res[i], tensor_gt[i])
for i in range(len(tensor_res_out_idx)):
if i == 0:
self.assertEqual(tensor_res_out_idx[i], 2)
else:
self.assertEqual(tensor_res_out_idx[i], 1)
# test backward
grad_tensor = scope.var('tmp_lod_tensor_array@GRAD')
grad_tensor_array = grad_tensor.get_lod_tensor_array()
self.assertEqual(10, len(grad_tensor_array))
for i in range(len(grad_tensor_array)):
if i == 0:
self.assertEqual(
numpy.array(grad_tensor_array[i])[0],
numpy.array(random_grad[i]))
self.assertEqual(
numpy.array(grad_tensor_array[i])[1],
numpy.array(random_grad[i + 1]))
if i == 1:
self.assertEqual(
numpy.array(grad_tensor_array[i]),
numpy.array(random_grad[i + 1]))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册