提交 397de907 编写于 作者: N nhzlx

merge develops

test=develop
......@@ -30,66 +30,61 @@ UNSET_VAR(PROTOBUF_LITE_LIBRARY)
UNSET_VAR(PROTOBUF_LIBRARY)
UNSET_VAR(PROTOBUF_INCLUDE_DIR)
UNSET_VAR(Protobuf_PROTOC_EXECUTABLE)
function(protobuf_generate_python SRCS)
# shameless copy from https://github.com/Kitware/CMake/blob/master/Modules/FindProtobuf.cmake
if(NOT ARGN)
message(SEND_ERROR "Error: PROTOBUF_GENERATE_PYTHON() called without any proto files")
return()
endif()
if(NOT COMMAND protobuf_generate_python) # before cmake 3.4, protobuf_genrerate_python is not defined.
function(protobuf_generate_python SRCS)
# shameless copy from https://github.com/Kitware/CMake/blob/master/Modules/FindProtobuf.cmake
if(NOT ARGN)
message(SEND_ERROR "Error: PROTOBUF_GENERATE_PYTHON() called without any proto files")
return()
endif()
if(PROTOBUF_GENERATE_CPP_APPEND_PATH)
# Create an include path for each file specified
foreach(FIL ${ARGN})
get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
get_filename_component(ABS_PATH ${ABS_FIL} PATH)
list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
if(${_contains_already} EQUAL -1)
list(APPEND _protobuf_include_path -I ${ABS_PATH})
endif()
endforeach()
else()
set(_protobuf_include_path -I ${CMAKE_CURRENT_SOURCE_DIR})
endif()
if(DEFINED PROTOBUF_IMPORT_DIRS AND NOT DEFINED Protobuf_IMPORT_DIRS)
set(Protobuf_IMPORT_DIRS "${PROTOBUF_IMPORT_DIRS}")
endif()
if(DEFINED Protobuf_IMPORT_DIRS)
foreach(DIR ${Protobuf_IMPORT_DIRS})
get_filename_component(ABS_PATH ${DIR} ABSOLUTE)
list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
if(${_contains_already} EQUAL -1)
list(APPEND _protobuf_include_path -I ${ABS_PATH})
endif()
endforeach()
endif()
set(${SRCS})
if(PROTOBUF_GENERATE_CPP_APPEND_PATH)
# Create an include path for each file specified
foreach(FIL ${ARGN})
get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
get_filename_component(FIL_WE ${FIL} NAME_WE)
if(NOT PROTOBUF_GENERATE_CPP_APPEND_PATH)
get_filename_component(FIL_DIR ${FIL} DIRECTORY)
if(FIL_DIR)
set(FIL_WE "${FIL_DIR}/${FIL_WE}")
endif()
get_filename_component(ABS_PATH ${ABS_FIL} PATH)
list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
if(${_contains_already} EQUAL -1)
list(APPEND _protobuf_include_path -I ${ABS_PATH})
endif()
endforeach()
else()
set(_protobuf_include_path -I ${CMAKE_CURRENT_SOURCE_DIR})
endif()
if(DEFINED PROTOBUF_IMPORT_DIRS AND NOT DEFINED Protobuf_IMPORT_DIRS)
set(Protobuf_IMPORT_DIRS "${PROTOBUF_IMPORT_DIRS}")
endif()
list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py")
add_custom_command(
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py"
COMMAND ${Protobuf_PROTOC_EXECUTABLE} --python_out ${CMAKE_CURRENT_BINARY_DIR} ${_protobuf_include_path} ${ABS_FIL}
DEPENDS ${ABS_FIL} ${Protobuf_PROTOC_EXECUTABLE}
COMMENT "Running Python protocol buffer compiler on ${FIL}"
VERBATIM )
if(DEFINED Protobuf_IMPORT_DIRS)
foreach(DIR ${Protobuf_IMPORT_DIRS})
get_filename_component(ABS_PATH ${DIR} ABSOLUTE)
list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
if(${_contains_already} EQUAL -1)
list(APPEND _protobuf_include_path -I ${ABS_PATH})
endif()
endforeach()
endif()
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
endfunction()
endif()
set(${SRCS})
foreach(FIL ${ARGN})
get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
get_filename_component(FIL_WE ${FIL} NAME_WE)
if(NOT PROTOBUF_GENERATE_CPP_APPEND_PATH)
get_filename_component(FIL_DIR ${FIL} DIRECTORY)
if(FIL_DIR)
set(FIL_WE "${FIL_DIR}/${FIL_WE}")
endif()
endif()
list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py")
add_custom_command(
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py"
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} --python_out ${CMAKE_CURRENT_BINARY_DIR} ${_protobuf_include_path} ${ABS_FIL}
DEPENDS ${ABS_FIL} ${PROTOBUF_PROTOC_EXECUTABLE}
COMMENT "Running Python protocol buffer compiler on ${FIL}"
VERBATIM )
endforeach()
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
endfunction()
# Print and set the protobuf library information,
# finish this cmake process and exit from this file.
......@@ -126,6 +121,7 @@ macro(PROMPT_PROTOBUF_LIB)
# FIND_Protobuf.cmake uses `Protobuf_PROTOC_EXECUTABLE`.
# make `protobuf_generate_cpp` happy.
SET(Protobuf_PROTOC_EXECUTABLE ${PROTOBUF_PROTOC_EXECUTABLE})
FOREACH(dep ${protobuf_DEPS})
ADD_DEPENDENCIES(protobuf ${dep})
ADD_DEPENDENCIES(protobuf_lite ${dep})
......
......@@ -179,6 +179,7 @@ paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], vara
paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.similarity_focus ArgSpec(args=['input', 'axis', 'indexes', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
......@@ -201,6 +202,7 @@ paddle.fluid.layers.create_tensor ArgSpec(args=['dtype', 'name', 'persistable'],
paddle.fluid.layers.create_parameter ArgSpec(args=['shape', 'dtype', 'name', 'attr', 'is_bias', 'default_initializer'], varargs=None, keywords=None, defaults=(None, None, False, None))
paddle.fluid.layers.create_global_var ArgSpec(args=['shape', 'value', 'dtype', 'persistable', 'force_cpu', 'name'], varargs=None, keywords=None, defaults=(False, False, None))
paddle.fluid.layers.cast ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.tensor_array_to_tensor ArgSpec(args=['input', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.concat ArgSpec(args=['input', 'axis', 'name'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.sums ArgSpec(args=['input', 'out'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.assign ArgSpec(args=['input', 'output'], varargs=None, keywords=None, defaults=(None,))
......
......@@ -259,6 +259,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
if (row_size >= 0) {
ss << "[row_size=" << row_size << "]";
}
std::string dtype = GetDtype(*scope, output.second[i]);
ss << ":" << dtype;
ss << "[" << GetDims(*scope, var_name, true) << "]";
ss << "(" << GetLoD(*scope, var_name) << ")";
}
......
......@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h"
namespace paddle {
......@@ -24,5 +27,27 @@ class VarTypeInference {
virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0;
};
class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const final {
auto in_out_var_names = this->GetInputOutputWithSameType();
for (auto& i_o_n : in_out_var_names) {
auto& x_name = op_desc.Input(i_o_n.first).at(0);
auto& out_name = op_desc.Output(i_o_n.second).at(0);
auto& x = block->FindRecursiveOrCreateVar(x_name);
auto& out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(x.GetType());
out.SetDataType(x.GetDataType());
}
}
protected:
virtual std::unordered_map<std::string, std::string>
GetInputOutputWithSameType() const = 0;
};
} // namespace framework
} // namespace paddle
......@@ -113,7 +113,9 @@ void Analyzer::Run(Argument* argument) {
passes.push_back("infer_clean_graph_pass");
passes.push_back("graph_viz_pass"); // add graphviz for debug.
for (auto& pass : ir_passes_) {
if (!disabled_ir_passes_.count(pass)) {
// skip mkldnn pass when use_mkldnn_ = false;
bool skip_pass = (!use_mkldnn_) && pass.find("mkldnn") != std::string::npos;
if (!disabled_ir_passes_.count(pass) && !skip_pass) {
passes.push_back(pass);
passes.push_back("graph_viz_pass"); // add graphviz for debug.
}
......
......@@ -317,6 +317,7 @@ op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor)
op_library(save_combine_op DEPS lod_tensor)
op_library(load_combine_op DEPS lod_tensor)
op_library(tensor_array_to_tensor_op DEPS concat_op)
op_library(concat_op DEPS concat_and_split)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
......
......@@ -91,16 +91,12 @@ class ActivationOp : public framework::OperatorWithKernel {
}
};
class ActivationOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto x_name = op_desc.Input("X")[0];
auto out_name = op_desc.Output("Out")[0];
auto& x = block->FindRecursiveOrCreateVar(x_name);
auto& out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(x.GetType());
out.SetDataType(x.GetDataType());
class ActivationOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
}
};
......
......@@ -170,6 +170,15 @@ The required data format for this layer is one of the following:
}
};
class BatchNormOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
}
};
template <typename T>
class BatchNormKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
......@@ -525,7 +534,7 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
namespace ops = paddle::operators;
REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
ops::BatchNormGradMaker);
ops::BatchNormOpInferVarType, ops::BatchNormGradMaker);
REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
REGISTER_OP_CPU_KERNEL(
......
......@@ -224,6 +224,15 @@ $$
)DOC");
}
class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{
{"Input", /*->*/ "Output"}};
}
};
void Conv3DOpMaker::Make() {
AddInput(
"Input",
......@@ -365,6 +374,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
namespace ops = paddle::operators;
REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker,
ops::ConvOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad);
......@@ -372,7 +382,9 @@ REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad);
REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad);
REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker,
ops::ConvOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad);
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cross_entropy_op.h"
#include <string>
namespace paddle {
namespace operators {
......@@ -179,6 +180,15 @@ or not. But the output only shares the LoD information with input X.
)DOC");
}
};
class CrossEntropyOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
}
};
} // namespace operators
} // namespace paddle
......@@ -186,6 +196,7 @@ namespace ops = paddle::operators;
using CPUCtx = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
ops::CrossEntropyOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
......
......@@ -75,16 +75,12 @@ class ElementwiseOp : public framework::OperatorWithKernel {
}
};
class ElementwiseOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto x_name = op_desc.Input("X")[0];
auto out_name = op_desc.Output("Out")[0];
auto &x = block->FindRecursiveOrCreateVar(x_name);
auto &out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(x.GetType());
out.SetDataType(x.GetDataType());
class ElementwiseOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
}
};
......
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mean_op.h"
#include <string>
namespace paddle {
namespace operators {
......@@ -42,6 +42,14 @@ Mean Operator calculates the mean of all elements in X.
}
};
class MeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
}
};
class MeanGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -50,6 +58,14 @@ class MeanGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X"));
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("X")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class MeanGradMaker : public framework::SingleGradOpDescMaker {
......@@ -71,7 +87,8 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradMaker);
REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType,
ops::MeanGradMaker);
REGISTER_OPERATOR(mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(
mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -126,6 +126,14 @@ or not. But the output only shares the LoD information with input $X$.
}
};
class MulOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
}
};
class MulGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -178,7 +186,8 @@ class MulOpGradMaker : public framework::SingleGradOpDescMaker {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGradMaker);
REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpInferVarType,
ops::MulOpGradMaker);
REGISTER_OPERATOR(mul_grad, ops::MulGradOp);
REGISTER_OP_CPU_KERNEL(
mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -40,7 +40,7 @@ int PoolOutputSize(int input_size, int filter_size, int padding, int stride,
return output_size;
}
void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Out(Output) of Pooling should not be null.");
......@@ -81,7 +81,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
}
framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
......@@ -104,7 +104,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
layout_, library_);
}
void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null.");
......@@ -112,7 +112,7 @@ void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
}
framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
......@@ -262,6 +262,14 @@ Example:
)DOC");
}
class PoolOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
}
};
void Pool3dOpMaker::Make() {
AddInput("X",
"(Tensor) The input tensor of pooling operator. "
......@@ -372,6 +380,7 @@ Example:
namespace ops = paddle::operators;
REGISTER_OPERATOR(pool2d, ops::PoolOp, ops::Pool2dOpMaker,
ops::PoolOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(pool2d_grad, ops::PoolOpGrad);
......@@ -383,6 +392,7 @@ REGISTER_OP_CPU_KERNEL(
ops::PoolGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OPERATOR(pool3d, ops::PoolOp, ops::Pool3dOpMaker,
ops::PoolOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(pool3d_grad, ops::PoolOpGrad);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/similarity_focus_op.h"
namespace paddle {
namespace operators {
class SimilarityFocusOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), a 4-D tensor with shape,"
" [BatchSize, X, Y, Z]");
AddOutput("Out",
"(Tensor, default Tensor<float>), the similarity focus mask"
" with the same shape of input X.");
AddAttr<int>("axis",
"(int32), indicating the dimension to be select. It can"
" only be 1, 2, or 3.");
AddAttr<std::vector<int>>("indexes",
"(std::vector<int32>), indicating the indexes"
" of the selected dimension.");
AddComment(R"DOC(
SimilarityFocus Operator.
Generate a similarity focus mask with the same shape of input using the following method:
1. Extract the 3-D tensor(here the first dimension is BatchSize) corresponding
to the axis according to the indexes. For example, if axis=1 and indexes=[a],
it will get the matrix T=X[:, a, :, :]. In this case, if the shape of input X
is (BatchSize, A, B, C), the shape of tensor T is (BatchSize, B, C).
2. For each index, find the largest numbers in the tensor T, so that the same
row and same column has at most one number(what it means is that if the
largest number has been found in the i-th row and the j-th column, then
the numbers in the i-th row or j-th column will be skipped. And then the
next largest number will be selected from the remaining numbers. Obviously
there will be min(B, C) numbers), and mark the corresponding position of the
3-D similarity focus mask as 1, otherwise as 0. Do elementwise-or for
each index.
3. Broadcast the 3-D similarity focus mask to the same shape of input X.
Refer to `Similarity Focus Layer <http://www.aclweb.org/anthology/N16-1108>`_
)DOC");
}
};
class SimilarityFocusOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 4, "Input(X)'s rank should be 4.");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
platform::CPUPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(similarity_focus, ops::SimilarityFocusOp,
ops::SimilarityFocusOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(similarity_focus, ops::SimilarityFocusKernel<float>,
ops::SimilarityFocusKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <cstring>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class SimilarityFocusKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
Tensor* out = context.Output<Tensor>("Out");
const Tensor* x = context.Input<Tensor>("X");
T* out_data = out->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>();
int axis = context.Attr<int>("axis");
std::vector<int> indexes = context.Attr<std::vector<int>>("indexes");
int64_t batch_size = x->dims()[0];
int64_t dim[4];
for (int i = 1; i <= 3; ++i) {
dim[i] = x->dims()[i];
}
if (indexes.size() < 1) {
PADDLE_THROW("Indexes' size can not be 0.");
}
for (auto index : indexes) {
if (dim[axis] < index) {
PADDLE_THROW("Index exceeds tensor shape limit.");
}
}
int64_t array_size = 1;
for (int i = 1; i <= 3; ++i) {
if (i != axis) {
array_size *= dim[i];
}
}
std::vector<std::pair<T, int64_t>> array(array_size);
bool (*cmp)(std::pair<T, int64_t>, std::pair<T, int64_t>) = [](
std::pair<T, int64_t> x, std::pair<T, int64_t> y) {
return x.first > y.first;
};
int64_t (*compute_index)(int64_t*, int, int, int, int) = [](
int64_t* dim, int d1, int d2, int d3, int d4) {
return d1 * dim[1] * dim[2] * dim[3] + d2 * dim[2] * dim[3] +
d3 * dim[3] + d4;
};
memset(out_data, 0, sizeof(T) * batch_size * dim[1] * dim[2] * dim[3]);
for (int i = 0; i < batch_size; ++i) {
for (auto index : indexes) {
if (axis == 1) {
for (int j = 0; j < dim[2]; ++j) {
for (int k = 0; k < dim[3]; ++k) {
array[j * dim[3] + k] = std::make_pair(
x_data[compute_index(dim, i, index, j, k)], j * dim[3] + k);
}
}
std::sort(array.begin(), array.end(), cmp);
int tag_num = 0;
std::vector<bool> tag2(dim[2]), tag3(dim[3]);
for (auto x : array) {
int idx2 = x.second / dim[3];
int idx3 = x.second % dim[3];
if (tag2[idx2] || tag3[idx3]) {
continue;
}
tag_num++;
tag2[idx2] = true;
tag3[idx3] = true;
for (int j = 0; j < dim[1]; ++j) {
out_data[compute_index(dim, i, j, idx2, idx3)] = 1;
}
if (tag_num == std::min(dim[2], dim[3])) {
break;
}
}
} else if (axis == 2) {
for (int j = 0; j < dim[1]; ++j) {
for (int k = 0; k < dim[3]; ++k) {
array[j * dim[3] + k] = std::make_pair(
x_data[compute_index(dim, i, j, index, k)], j * dim[3] + k);
}
}
std::sort(array.begin(), array.end(), cmp);
int tag_num = 0;
std::vector<bool> tag1(dim[1]), tag3(dim[3]);
for (auto x : array) {
int idx1 = x.second / dim[3];
int idx3 = x.second % dim[3];
if (tag1[idx1] || tag3[idx3]) {
continue;
}
tag_num++;
tag1[idx1] = true;
tag3[idx3] = true;
for (int j = 0; j < dim[2]; ++j) {
out_data[compute_index(dim, i, idx1, j, idx3)] = 1;
}
if (tag_num == std::min(dim[1], dim[3])) {
break;
}
}
} else if (axis == 3) {
for (int j = 0; j < dim[1]; ++j) {
for (int k = 0; k < dim[2]; ++k) {
array[j * dim[2] + k] = std::make_pair(
x_data[compute_index(dim, i, j, k, index)], j * dim[2] + k);
}
}
std::sort(array.begin(), array.end(), cmp);
int tag_num = 0;
std::vector<bool> tag1(dim[1]), tag2(dim[2]);
for (auto x : array) {
int idx1 = x.second / dim[2];
int idx2 = x.second % dim[2];
if (tag1[idx1] || tag2[idx2]) {
continue;
}
tag_num++;
tag1[idx1] = true;
tag2[idx2] = true;
for (int j = 0; j < dim[3]; ++j) {
out_data[compute_index(dim, i, idx1, idx2, j)] = 1;
}
if (tag_num == std::min(dim[1], dim[2])) {
break;
}
}
} else {
PADDLE_THROW("Axis must be 1 or 2 or 3");
}
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -124,6 +124,14 @@ For each row $i$ and each column $j$ in the matrix, we have:
}
};
class SoftmaxOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
}
};
class SoftmaxOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -196,7 +204,7 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker {
namespace ops = paddle::operators;
REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
ops::SoftmaxOpGradMaker);
ops::SoftmaxOpInferVarType, ops::SoftmaxOpGradMaker);
REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(
softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace operators {
using framework::Tensor;
void LodTensorArray2LodTensorVector(const framework::Scope &scope,
const std::string &base_name,
const std::string &lod_tensor_array_name,
std::vector<std::string> *res_names) {
auto &inx =
scope.FindVar(lod_tensor_array_name)->Get<framework::LoDTensorArray>();
for (size_t i = 0; i < inx.size(); i++) {
std::string var_name = base_name + std::to_string(i);
framework::Variable *g_feed_value =
const_cast<framework::Scope &>(scope).Var(var_name);
auto &feed_input =
*(g_feed_value->GetMutable<paddle::framework::LoDTensor>());
feed_input.ShareDataWith(inx[i]);
res_names->push_back(var_name);
}
}
void LodTensorVectorResizeFromLodTensorArray(
const framework::Scope &scope, const std::string &base_name,
const std::string &lod_tensor_array_name,
std::vector<std::string> *res_names) {
auto &inx =
scope.FindVar(lod_tensor_array_name)->Get<framework::LoDTensorArray>();
for (size_t i = 0; i < inx.size(); i++) {
std::string var_name = base_name + std::to_string(i);
framework::Variable *g_feed_value =
const_cast<framework::Scope &>(scope).Var(var_name);
auto &feed_input =
*(g_feed_value->GetMutable<paddle::framework::LoDTensor>());
auto dims = inx[i].dims();
feed_input.Resize(dims);
res_names->push_back(var_name);
}
}
void LodTensorArrayCreateFromLodTensorArray(
const framework::Scope &scope,
const std::string &input_lod_tensor_array_name,
const std::string &output_lod_tensor_array_name) {
auto &inx = scope.FindVar(input_lod_tensor_array_name)
->Get<framework::LoDTensorArray>();
auto &grad_inx = *scope.FindVar(output_lod_tensor_array_name)
->GetMutable<framework::LoDTensorArray>();
for (size_t i = 0; i < inx.size(); i++) {
std::string var_name = output_lod_tensor_array_name + std::to_string(i);
framework::Variable *g_feed_value =
const_cast<framework::Scope &>(scope).Var(var_name);
auto &feed_input =
*(g_feed_value->GetMutable<paddle::framework::LoDTensor>());
grad_inx.push_back(feed_input);
}
}
class LoDTensorArray2TensorOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto axis = Attr<int>("axis");
framework::AttributeMap attrs;
attrs["axis"] = axis;
auto &inx = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
auto &out =
*scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
auto &out_inx =
*scope.FindVar(Output("OutIndex"))->GetMutable<framework::LoDTensor>();
const size_t n = inx.size();
PADDLE_ENFORCE_GT(n, 0, "Input tensorarray size should > 0.");
std::string base_name = Inputs("X")[0];
std::vector<std::string> names;
// get the input tensorarray items' dim in out_inx
auto out_inx_dim = out_inx.dims();
out_inx_dim[0] = inx.size();
out_inx.Resize(out_inx_dim);
std::string var_name = "out_index";
framework::Variable *tmp_index_var =
const_cast<framework::Scope &>(scope).Var(var_name);
auto &tmp_index_tensor =
*(tmp_index_var->GetMutable<paddle::framework::LoDTensor>());
tmp_index_tensor.Resize(out_inx_dim);
int *tmp_index_data =
tmp_index_tensor.mutable_data<int>(platform::CPUPlace());
auto out_dims = inx[0].dims();
size_t out_dim_sum = 0;
for (size_t index = 0; index < inx.size(); index++) {
auto inx_dims = inx[index].dims();
out_dim_sum += inx_dims[axis];
tmp_index_data[index] = inx_dims[axis];
}
out_inx.ShareDataWith(tmp_index_tensor);
// get input array items' dims
out_dims[axis] = out_dim_sum;
out.Resize(out_dims);
LodTensorArray2LodTensorVector(scope, base_name, Input("X"), &names);
// Invoke Reshape Op
auto concat_op = framework::OpRegistry::CreateOp(
"concat", {{"X", names}}, {{"Out", {Output("Out")}}}, attrs);
concat_op->Run(scope, place);
}
};
class LoDTensorArray2TensorOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input LoDTensorArray of tensor_array_to_tensor operator.");
AddOutput("Out", "Output tensor of tensor_array_to_tensor operator.");
AddOutput("OutIndex",
"Output input LoDTensorArray items' dims of "
"tensor_array_to_tensor operator.");
AddAttr<int>("axis",
"The axis along which the input tensors will be concatenated.")
.SetDefault(0);
AddComment(R"DOC(
tensor_array_to_tensor Operator.
Concatenate the input LoDTensorArray along dimension axis to the output Tensor.
Examples:
Input = {[1,2], [3,4], [5,6]}
axis = 0
Output = [[1,2],
[3,4],
[5,6]]
OutputIndex = [1,1,1]
)DOC");
}
};
class LoDTensorArray2TensorOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};
class LoDTensorArray2TensorGradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {}
};
class LoDTensorArray2TensorGradInferVarType
: public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &out_var : op_desc.Output(framework::GradVarName("X"))) {
block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
}
}
};
class LoDTensorArray2TensorGradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto axis = Attr<int>("axis");
framework::AttributeMap attrs;
attrs["axis"] = axis;
auto &inx = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
const size_t n = inx.size();
PADDLE_ENFORCE_GT(n, 0, "Input tensorarray size should > 0.");
std::string base_name = Inputs("X")[0];
std::vector<std::string> names;
LodTensorArray2LodTensorVector(scope, base_name, Input("X"), &names);
// grad
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
std::vector<std::string> grad_names;
LodTensorVectorResizeFromLodTensorArray(scope, "grad_name", Input("X"),
&grad_names);
auto concat_grad_op = framework::OpRegistry::CreateOp(
"concat_grad", {{"X", names}, {"Out@GRAD", {dout_name}}},
{{"X@GRAD", grad_names}}, attrs);
concat_grad_op->Run(scope, place);
LodTensorArrayCreateFromLodTensorArray(scope, Input("X"), dx_name);
auto &grad_inx =
*scope.FindVar(dx_name)->GetMutable<framework::LoDTensorArray>();
for (size_t i = 0; i < grad_names.size(); i++) {
std::string var_name = grad_names[i];
auto &feed_input = scope.FindVar(var_name)->Get<framework::LoDTensor>();
grad_inx[i].ShareDataWith(feed_input);
}
}
};
} // namespace operators
} // namespace paddle
USE_OP(concat);
namespace ops = paddle::operators;
REGISTER_OPERATOR(tensor_array_to_tensor, ops::LoDTensorArray2TensorOp,
ops::LoDTensorArray2TensorOpMaker,
ops::LoDTensorArray2TensorOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(tensor_array_to_tensor_grad, ops::LoDTensorArray2TensorGradOp,
ops::LoDTensorArray2TensorGradInferShape,
ops::LoDTensorArray2TensorGradInferVarType);
......@@ -34,6 +34,7 @@ from . import regularizer
from . import average
from . import metrics
from . import transpiler
from . import distribute_lookup_table
from .param_attr import ParamAttr, WeightNormParamAttr
from .data_feeder import DataFeeder
from .core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
LOOKUP_TABLE_TYPE = "lookup_table"
def find_distributed_lookup_table(program):
"""
Find distribute lookup table in program.
We only support one distribute table now.
:param program:
:return: table_name or None
"""
table_name = None
for op in program.global_block().ops:
if op.type == LOOKUP_TABLE_TYPE:
if op.attr('is_distributed') is True:
if table_name is None:
table_name = op.input("W")[0]
if table_name != op.input("W")[0]:
raise RuntimeError("all distributed lookup_table_ops"
" should have only one table")
else:
if table_name is not None:
assert op.input("W")[0] != table_name
return table_name
......@@ -160,6 +160,7 @@ __all__ = [
'affine_grid',
'sequence_reverse',
'affine_channel',
'similarity_focus',
'hash',
'grid_sampler',
'log_loss',
......@@ -7933,6 +7934,118 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
return out
def similarity_focus(input, axis, indexes, name=None):
"""
SimilarityFocus Operator
Generate a similarity focus mask with the same shape of input using the following method:
1. Extract the 3-D tensor(here the first dimension is BatchSize) corresponding
to the axis according to the indexes. For example, if axis=1 and indexes=[a],
it will get the matrix T=X[:, a, :, :]. In this case, if the shape of input X
is (BatchSize, A, B, C), the shape of tensor T is (BatchSize, B, C).
2. For each index, find the largest numbers in the tensor T, so that the same
row and same column has at most one number(what it means is that if the
largest number has been found in the i-th row and the j-th column, then
the numbers in the i-th row or j-th column will be skipped. And then the
next largest number will be selected from the remaining numbers. Obviously
there will be min(B, C) numbers), and mark the corresponding position of the
3-D similarity focus mask as 1, otherwise as 0. Do elementwise-or for
each index.
3. Broadcast the 3-D similarity focus mask to the same shape of input X.
Refer to `Similarity Focus Layer <http://www.aclweb.org/anthology/N16-1108>`_
.. code-block:: text
* Example :
Given a 4-D tensor x with the shape (BatchSize, C, A, B), where C is
the number of channels and the shape of feature map is (A, B):
x.shape = (2, 3, 2, 2)
x.data = [[[[0.8, 0.1],
[0.4, 0.5]],
[[0.9, 0.7],
[0.9, 0.9]],
[[0.8, 0.9],
[0.1, 0.2]]],
[[[0.2, 0.5],
[0.3, 0.4]],
[[0.9, 0.7],
[0.8, 0.4]],
[[0.0, 0.2],
[0.4, 0.7]]]]
Given axis: 1 (the axis of the channel)
Given indexes: [0]
then we get a 4-D tensor out with the same shape of input x:
out.shape = (2, 3, 2, 2)
out.data = [[[[1.0, 0.0],
[0.0, 1.0]],
[[1.0, 0.0],
[0.0, 1.0]],
[[1.0, 0.0],
[0.0, 1.0]]],
[[[0.0, 1.0],
[1.0, 0.0]],
[[0.0, 1.0],
[1.0, 0.0]],
[[0.0, 1.0],
[1.0, 0.0]]]]
Args:
input(Variable): The input tensor variable(default float). It should
be a 4-D tensor with shape [BatchSize, A, B, C].
axis(int): Indicating the dimension to be selected. It can only be
1, 2 or 3.
indexes(list): Indicating the indexes of the selected dimension.
Returns:
Variable: A tensor variable with the same shape and same type
as the input.
Examples:
.. code-block:: python
data = fluid.layers.data(
name='data', shape=[2, 3, 2, 2], dtype='float32')
x = fluid.layers.layer_norm(input=data, axis=1, indexes=[0])
"""
helper = LayerHelper('similarity_focus', **locals())
# check attrs
if isinstance(axis, int) is False:
raise TypeError("axis must be int type.")
if isinstance(indexes, list) is False:
raise TypeError("indexes must be list type.")
if axis != 1 and axis != 2 and axis != 3:
raise ValueError("axis must be 1, 2 or 3.")
if len(indexes) == 0:
raise ValueError("indexes can not be empty.")
if name is None:
out = helper.create_variable_for_type_inference(dtype=input.dtype)
else:
out = helper.create_variable(
name=name, dtype=input.dtype, persistable=False)
helper.append_op(
type='similarity_focus',
inputs={'X': input},
outputs={'Out': out},
attrs={"axis": axis,
"indexes": indexes})
return out
def hash(input, hash_size, num_hash=1, name=None):
"""
Hash the input to an integer whose value is less than the given hash size.
......
......@@ -24,10 +24,10 @@ from .layer_function_generator import templatedoc
import numpy
__all__ = [
'create_tensor', 'create_parameter', 'create_global_var', 'cast', 'concat',
'sums', 'assign', 'fill_constant_batch_size_like', 'fill_constant',
'argmin', 'argmax', 'argsort', 'ones', 'zeros', 'reverse', 'has_inf',
'has_nan', 'isfinite'
'create_tensor', 'create_parameter', 'create_global_var', 'cast',
'tensor_array_to_tensor', 'concat', 'sums', 'assign',
'fill_constant_batch_size_like', 'fill_constant', 'argmin', 'argmax',
'argsort', 'ones', 'zeros', 'reverse', 'has_inf', 'has_nan', 'isfinite'
]
......@@ -193,6 +193,60 @@ def concat(input, axis=0, name=None):
return out
def tensor_array_to_tensor(input, axis=1, name=None):
"""
This function concatenates the input LodTensorArray along the axis mentioned
and returns that as the output.
A simple example as below:
.. code-block:: text
Given:
input.data = {[[0.6, 0.1, 0.3],
[0.5, 0.3, 0.2]],
[[1.3],
[1.8]],
[[2.3, 2.1],
[2.5, 2.4]]}
axis = 1
Then:
output.data = [[0.6, 0.1, 0.3, 1.3, 2.3, 2.1],
[0.5, 0.3, 0.2, 1.8, 2.5, 2.4]]
output_index.data = [3, 1, 2]
Args:
input(list): Input LodTensorArray
axis(int): Integer axis along which the tensors will be concatenated
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: Output variable of the concatenation
Variable: The input LodTensorArray items' dims along the axis
Examples:
.. code-block:: python
output, output_index = fluid.layers.tensor_array_to_tensor(input=tensor_array)
"""
helper = LayerHelper('tensor_array_concat', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
out_index = helper.create_variable_for_type_inference(dtype="int32")
helper.append_op(
type='tensor_array_concat',
inputs={'X': input},
outputs={'Out': [out],
'OutIndex': [out_index]},
attrs={'axis': axis})
return out, out_index
def sums(input, out=None):
"""
This function performs the sum operation on the input and returns the
......
......@@ -13,21 +13,23 @@
# limitations under the License.
from __future__ import print_function
import re
import sys
from collections import defaultdict
from contextlib import contextmanager
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from . import framework
from . import layers
from . import unique_name
from .backward import append_backward
from .clip import append_gradient_clip_ops, error_clip_callback
from .framework import program_guard
from . import unique_name
from .initializer import Constant
from .layer_helper import LayerHelper
from .regularizer import append_regularization_ops
from .clip import append_gradient_clip_ops, error_clip_callback
from contextlib import contextmanager
from .layers import ops
from .regularizer import append_regularization_ops
__all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',
......@@ -85,7 +87,7 @@ class Optimizer(object):
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._learning_rate),
dtype='float32' if self._dtype == None else self._dtype,
dtype='float32' if self._dtype is None else self._dtype,
persistable=True)
def _global_learning_rate(self, program=None):
......@@ -245,6 +247,50 @@ class Optimizer(object):
end = len(global_block.ops)
return global_block._slice_ops(start, end)
def _process_distribute_lookuptable(self, param_grads, loss,
startup_program):
"""
Because distribute lookup table only support SGD optimizer for now, not support
other optimizer and regularization, so we should find the table parameter out,
and avoid to add regularization and other op for it, and add sgd optimize op
for it independently.
:param param_grads(list((Var, Var))): list of (param, grad) pair.
:param loss: the loss variable.
:param startup_program: the startup program
"""
program = loss.block.program
table_name = find_distributed_lookup_table(program)
table_param = None
table_grad = None
new_param_grads = []
for p, g in param_grads:
if p.name == table_name:
if table_param is not None:
raise RuntimeError(
"multi dist table var found, only support one now!")
table_param = p
table_grad = g
else:
new_param_grads.append((p, g))
sgd_op = None
if table_param is not None:
with program_guard(program, startup_program):
param_and_grad = [table_param, table_grad]
with table_param.block.program._optimized_guard(param_and_grad), \
framework.name_scope("optimizer"):
self._create_global_learning_rate()
# create the optimize op
sgd_op = loss.block.append_op(
type='sgd',
inputs={
"Param": table_param,
"Grad": table_grad,
"LearningRate":
self._create_param_lr(param_and_grad)
},
outputs={"ParamOut": param_and_grad[0]})
return new_param_grads, (table_param, table_grad), sgd_op
def minimize(self,
loss,
startup_program=None,
......@@ -260,6 +306,9 @@ class Optimizer(object):
params_grads = sorted(params_grads, key=lambda x: x[0].name)
params_grads, table_param_and_grad, table_optimize_op = \
self._process_distribute_lookuptable(params_grads, loss, startup_program)
params_grads = append_gradient_clip_ops(params_grads)
# Add regularization if any
......@@ -268,6 +317,9 @@ class Optimizer(object):
optimize_ops = self._create_optimization_pass(params_grads, loss,
startup_program)
if table_optimize_op is not None:
optimize_ops.append(table_optimize_op)
params_grads.append(table_param_and_grad)
return optimize_ops, params_grads
......
......@@ -38,7 +38,7 @@ depth = 8
mix_hidden_lr = 1e-3
IS_SPARSE = True
PASS_NUM = 10
PASS_NUM = 1
BATCH_SIZE = 10
embedding_name = 'emb'
......
......@@ -567,7 +567,6 @@ class TestDistLookupTable(TestDistLookupTableBase):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'uniform_random',
'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat',
'fake_init'
......@@ -639,7 +638,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
# 5 save table
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
trainer, _ = self.get_trainer(config)
trainer, trainer_startup = self.get_trainer(config)
self.assertEqual(len(trainer.blocks), 1)
ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
......@@ -653,6 +652,16 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
'recv', 'concat'
]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
startup_ops = [
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'uniform_random',
'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat',
'fake_init'
]
self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
startup_ops)
class TestDistLookupTableSliceSize(TestDistLookupTableBase):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
class TestSimilarityFocusOp(OpTest):
def setUp(self):
self.op_type = "similarity_focus"
batch_size = 2
x_dim, y_dim, z_dim = 3, 2, 2
self.inputs = {
'X': np.array([[[[0.8, 0.1], [0.4, 0.5]], [[0.9, 0.7], [0.9, 0.9]],
[[0.8, 0.9], [0.1, 0.2]]],
[[[0.2, 0.5], [0.3, 0.4]], [[0.9, 0.7], [0.8, 0.4]],
[[0.0, 0.2], [0.4, 0.7]]]]),
}
self.attrs = {
'axis': 1,
'indexes': [0],
}
output = None
for batch in range(batch_size):
res = np.zeros((1, y_dim, z_dim)).astype("float32").reshape(-1)
for index in self.attrs['indexes']:
channel = self.inputs['X'][batch, index, :, :].reshape(-1).copy(
)
tag1 = [0 for i in range(y_dim)]
tag2 = [0 for i in range(z_dim)]
cnt = 0
for i in range(channel.size):
index = channel.argmax()
idx1 = index // z_dim
idx2 = index % z_dim
if tag1[idx1] + tag2[idx2] == 0:
tag1[idx1] = 1
tag2[idx2] = 1
res[index] = 1
cnt += 1
if cnt == min(y_dim, z_dim):
break
channel[index] = -1
res = res.reshape(1, y_dim, z_dim).repeat([x_dim], axis=0)
res = res.reshape(1, x_dim, y_dim, z_dim)
if output is not None:
output = np.concatenate((output, res), axis=0)
else:
output = res
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
class TestSimilarityFocusOp_axis1(OpTest):
def setUp(self):
self.op_type = "similarity_focus"
batch_size = 3
x_dim, y_dim, z_dim = 4, 5, 6
self.inputs = {
'X': np.random.random(
(batch_size, x_dim, y_dim, z_dim)).astype("float32"),
}
self.attrs = {
'axis': 1,
'indexes': [0, 3],
}
output = None
for batch in range(batch_size):
res = np.zeros((1, y_dim, z_dim)).astype("float32").reshape(-1)
for index in self.attrs['indexes']:
channel = self.inputs['X'][batch, index, :, :].reshape(-1).copy(
)
tag1 = [0 for i in range(y_dim)]
tag2 = [0 for i in range(z_dim)]
cnt = 0
for i in range(channel.size):
index = channel.argmax()
idx1 = index // z_dim
idx2 = index % z_dim
if tag1[idx1] + tag2[idx2] == 0:
tag1[idx1] = 1
tag2[idx2] = 1
res[index] = 1
cnt += 1
if cnt == min(y_dim, z_dim):
break
channel[index] = -1
res = res.reshape(1, y_dim, z_dim)
res = res.repeat([x_dim], axis=0)
res = res.reshape(1, x_dim, y_dim, z_dim)
if output is not None:
output = np.concatenate((output, res), axis=0)
else:
output = res
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
class TestSimilarityFocusOp_axis2(OpTest):
def setUp(self):
self.op_type = "similarity_focus"
batch_size = 6
x_dim, y_dim, z_dim = 7, 8, 9
self.inputs = {
'X': np.random.random(
(batch_size, x_dim, y_dim, z_dim)).astype("float32"),
}
self.attrs = {
'axis': 2,
'indexes': [0, 3, 5],
}
output = None
for batch in range(batch_size):
res = np.zeros((x_dim, 1, z_dim)).astype("float32").reshape(-1)
for index in self.attrs['indexes']:
channel = self.inputs['X'][batch, :, index, :].reshape(-1).copy(
)
tag1 = [0 for i in range(x_dim)]
tag2 = [0 for i in range(z_dim)]
cnt = 0
for i in range(channel.size):
index = channel.argmax()
idx1 = index // z_dim
idx2 = index % z_dim
if tag1[idx1] + tag2[idx2] == 0:
tag1[idx1] = 1
tag2[idx2] = 1
res[index] = 1
cnt += 1
if cnt == min(x_dim, z_dim):
break
channel[index] = -1
res = res.reshape(x_dim, 1, z_dim)
res = res.repeat([y_dim], axis=1)
res = res.reshape(1, x_dim, y_dim, z_dim)
if output is not None:
output = np.concatenate((output, res), axis=0)
else:
output = res
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
class TestSimilarityFocusOp_axis3(OpTest):
def setUp(self):
self.op_type = "similarity_focus"
batch_size = 64
x_dim, y_dim, z_dim = 48, 48, 13
self.inputs = {
'X': np.random.random(
(batch_size, x_dim, y_dim, z_dim)).astype("float32"),
}
self.attrs = {
'axis': 3,
'indexes': [0, 2, 7, 9],
}
output = None
for batch in range(batch_size):
res = np.zeros((x_dim, y_dim, 1)).astype("float32").reshape(-1)
for index in self.attrs['indexes']:
channel = self.inputs['X'][batch, :, :, index].reshape(-1).copy(
)
tag1 = [0 for i in range(x_dim)]
tag2 = [0 for i in range(y_dim)]
cnt = 0
for i in range(channel.size):
index = channel.argmax()
idx1 = index // y_dim
idx2 = index % y_dim
if tag1[idx1] + tag2[idx2] == 0:
tag1[idx1] = 1
tag2[idx2] = 1
res[index] = 1
cnt += 1
if cnt == min(x_dim, y_dim):
break
channel[index] = -1
res = res.reshape(x_dim, y_dim, 1)
res = res.repeat([z_dim], axis=2)
res = res.reshape(1, x_dim, y_dim, z_dim)
if output is not None:
output = np.concatenate((output, res), axis=0)
else:
output = res
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.fluid.executor import Executor
class TestLoDTensorArrayConcat(unittest.TestCase):
def setUp(self):
self.op_type = "tensor_array_to_tensor"
self.attrs = {"axis": 0}
self.outputs = ["Out"]
def test_get_set(self):
scope = core.Scope()
program = fluid.Program()
block = program.global_block()
input_arr = block.create_var(
name="tmp_lod_tensor_array",
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY)
input_arr.persistable = True
input_arr_var = scope.var('tmp_lod_tensor_array')
input_tensor_array = input_arr_var.get_lod_tensor_array()
self.assertEqual(0, len(input_tensor_array))
cpu = core.CPUPlace()
for i in range(10):
t = core.LoDTensor()
if i == 0:
t.set(numpy.array([[i], [i]], dtype='float32'), cpu)
else:
t.set(numpy.array([[i]], dtype='float32'), cpu)
input_tensor_array.append(t)
self.assertEqual(10, len(input_tensor_array))
random_grad = numpy.random.random_sample([11]).astype(numpy.float32)
y_out = block.create_var(name="Out")
y_out.persistable = True
y_out_index = block.create_var(name="OutIndex")
y_out_index.persistable = True
y_grad_arr = block.create_var(
name='Out@GRAD', dtype='float32', shape=[11])
y_grad_arr.persistable = True
y_grad = scope.var('Out@GRAD')
y_grad_tensor = y_grad.get_tensor()
y_grad_tensor.set(random_grad, cpu)
op = block.append_op(
type=self.op_type,
inputs={"X": input_arr},
outputs={"Out": y_out,
"OutIndex": y_out_index},
attrs=self.attrs)
out_grad = block.create_var(
name="tmp_lod_tensor_array@GRAD",
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY)
out_grad.persistable = True
grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(op.desc,
set(), [])
grad_op_desc = grad_op_desc_list[0]
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(grad_op_desc)
for var_name in grad_op_desc.output_arg_names():
block.desc.var(var_name.encode("ascii"))
grad_op_desc.infer_var_type(block.desc)
grad_op_desc.infer_shape(block.desc)
for arg in grad_op_desc.output_arg_names():
grad_var = block.desc.find_var(arg.encode("ascii"))
grad_var.set_dtype(core.VarDesc.VarType.FP32)
fetch_list = []
fetch_list.append(block.var('Out'))
fetch_list.append(block.var('OutIndex'))
exe = fluid.Executor(fluid.CPUPlace())
out = exe.run(program, fetch_list=fetch_list, scope=scope)
#print ("index: ", numpy.array(out[1]))
# test forward
tensor_res = numpy.array(out[0])
tensor_res_out_idx = numpy.array(out[1])
tensor_gt = numpy.array(
[0] + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='float32')
self.assertEqual(len(tensor_res), len(tensor_gt))
self.assertEqual(len(tensor_res_out_idx), 10)
for i in range(len(tensor_res)):
self.assertEqual(tensor_res[i], tensor_gt[i])
for i in range(len(tensor_res_out_idx)):
if i == 0:
self.assertEqual(tensor_res_out_idx[i], 2)
else:
self.assertEqual(tensor_res_out_idx[i], 1)
# test backward
grad_tensor = scope.var('tmp_lod_tensor_array@GRAD')
grad_tensor_array = grad_tensor.get_lod_tensor_array()
self.assertEqual(10, len(grad_tensor_array))
for i in range(len(grad_tensor_array)):
if i == 0:
self.assertEqual(
numpy.array(grad_tensor_array[i])[0],
numpy.array(random_grad[i]))
self.assertEqual(
numpy.array(grad_tensor_array[i])[1],
numpy.array(random_grad[i + 1]))
if i == 1:
self.assertEqual(
numpy.array(grad_tensor_array[i]),
numpy.array(random_grad[i + 1]))
if __name__ == '__main__':
unittest.main()
......@@ -31,18 +31,17 @@ Steps to transpile pserver:
"""
import math
import sys
import numpy as np
import collections
import six
import logging
from .ps_dispatcher import RoundRobin, HashName, PSDispatcher
from .ps_dispatcher import RoundRobin, PSDispatcher
from .. import core, framework, unique_name
from ..framework import Program, default_main_program, \
default_startup_program, Block, \
Parameter, grad_var_name
from .details import *
from ..distribute_lookup_table import find_distributed_lookup_table
from functools import reduce
LOOKUP_TABLE_TYPE = "lookup_table"
......@@ -292,7 +291,8 @@ class DistributeTranspiler(object):
self.optimize_ops, self.params_grads = self._get_optimize_pass()
ps_dispatcher = self.config.split_method(self.pserver_endpoints)
self.has_distributed_lookup_table = self._has_distributed_lookup_table()
self.table_name = find_distributed_lookup_table(self.origin_program)
self.has_distributed_lookup_table = self.table_name != None
self.param_name_to_grad_name = dict()
self.grad_name_to_param_name = dict()
for param_var, grad_var in self.params_grads:
......@@ -966,28 +966,6 @@ to transpile() call.")
# ====================== private transpiler functions =====================
def _has_distributed_lookup_table(self):
# process lookup_table_op
# 1. check all lookup_table_op is distributed
# 2. check all lookup_table_op share the same table.
distributed_lookup_table_ops = []
# support only one distributed_lookup_table now
self.table_name = None
for op in self.origin_program.global_block().ops:
if op.type == LOOKUP_TABLE_TYPE:
if op.attr('is_distributed') is True:
if self.table_name is None:
self.table_name = op.input("W")[0]
if self.table_name != op.input("W")[0]:
raise RuntimeError("all distributed lookup_table_ops"
" should have only one table")
distributed_lookup_table_ops.append(op)
else:
if self.table_name is not None:
assert op.input("W")[0] != self.table_name
return len(distributed_lookup_table_ops) > 0
def _update_dist_lookup_table_vars(self, param_list, grad_list,
params_grads):
# TODO(wuyi): put find a way to put dist lookup table stuff all together.
......@@ -1341,7 +1319,6 @@ to transpile() call.")
"""
create a new block to handle save checkpoint.
"""
import os
pserver_program.global_block().create_var(
name="kLookupTablePath",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册