未验证 提交 1ce478f1 编写于 作者: Y yuyang18

Polish reshape op

上级 81f22bb2
......@@ -76,6 +76,19 @@ class OpRegistry {
// Primary template, declared only. It is specialized on `at_end`: the `false`
// specialization registers one kernel and then invokes the functor for index
// I + 1, with `at_end` computed as (I + 1 == number of kernel types).
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor;
template <typename PlaceType, typename T, typename KernelType>
inline void RegisterKernelClass(const char* op_type, const char* library_type) {
std::string library(library_type);
std::string data_layout = "ANYLAYOUT";
if (library == "MKLDNN") {
data_layout = "MKLDNNLAYOUT";
}
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
StringToDataLayout(data_layout),
StringToLibraryType(library_type));
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType());
}
template <typename PlaceType, size_t I, typename... KernelTypes>
struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
using KERNEL_TYPE =
......@@ -83,16 +96,7 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
void operator()(const char* op_type, const char* library_type) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
std::string library(library_type);
std::string data_layout = "ANYLAYOUT";
if (library == "MKLDNN") {
data_layout = "MKLDNNLAYOUT";
}
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
StringToDataLayout(data_layout),
StringToLibraryType(library_type));
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type);
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
func;
......@@ -116,6 +120,47 @@ class OpKernelRegistrar : public Registrar {
}
};
// Primary template, declared only. The type pack is consumed two entries at a
// time as (data type, kernel class) pairs: the `false` specialization registers
// the pair starting at index I, the `true` specialization ends the recursion.
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctorEx;
template <typename PlaceType, typename... DataTypeAndKernelType>
class OpKernelRegistrarEx : public Registrar {
public:
explicit OpKernelRegistrarEx(const char* op_type, const char* library_type) {
OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...>
func;
func(op_type, library_type);
}
};
// Recursion terminator (at_end == true): there is nothing left to register, so
// the call is a no-op and both arguments are ignored.
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
struct OpKernelRegistrarFunctorEx<PlaceType, true, I,
                                  DataTypeAndKernelType...> {
  void operator()(const char* /*op_type*/,
                  const char* /*library_type*/) const {}
};
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
DataTypeAndKernelType...> {
using KERNEL_TYPE =
typename std::tuple_element<I + 1,
std::tuple<DataTypeAndKernelType...>>::type;
using T =
typename std::tuple_element<I,
std::tuple<DataTypeAndKernelType...>>::type;
void operator()(const char* op_type, const char* library_type) const {
RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type);
constexpr auto size =
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
OpKernelRegistrarFunctorEx<PlaceType, I + 2 >= size, I + 2,
DataTypeAndKernelType...>
func;
func(op_type, library_type);
}
};
/**
* check if MACRO is used in GLOBAL NAMESPACE.
*/
......@@ -174,6 +219,25 @@ class OpKernelRegistrar : public Registrar {
/**
 * Shorthand for REGISTER_OP_KERNEL with library type CPU and place class
 * ::paddle::platform::CPUPlace; the variadic arguments are forwarded as-is.
 */
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
/**
 * Registers kernels given as a flat list of (data type, kernel class) pairs.
 *
 * Expands to:
 *  - a STATIC_ASSERT_GLOBAL_NAMESPACE check carrying the message below;
 *  - a static OpKernelRegistrarEx<place_class, __VA_ARGS__> object named
 *    __op_kernel_registrar_<op_type>_<library_type>__, constructed with the
 *    stringized op_type and library_type;
 *  - a function TouchOpKernelRegistrar_<op_type>_<library_type>() that calls
 *    Touch() on that object and returns 0.
 *
 * op_type and library_type are token-pasted into identifiers, so both must be
 * plain identifier tokens.
 */
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##op_type##_##library_type##__, \
"REGISTER_OP_KERNEL_EX must be called in global namespace"); \
static ::paddle::framework::OpKernelRegistrarEx<place_class, __VA_ARGS__> \
__op_kernel_registrar_##op_type##_##library_type##__(#op_type, \
#library_type); \
int TouchOpKernelRegistrar_##op_type##_##library_type() { \
__op_kernel_registrar_##op_type##_##library_type##__.Touch(); \
return 0; \
}
/**
 * Shorthand for REGISTER_OP_KERNEL_EX with library type CUDA and place class
 * ::paddle::platform::CUDAPlace.
 *
 * The first argument must be forwarded as `op_type`. The previous spelling
 * `p_type` was not a macro parameter, so it was passed through literally and
 * every kernel ended up registered under the operator name "p_type".
 */
#define REGISTER_OP_CUDA_KERNEL_EX(op_type, ...)                       \
  REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace,  \
                        __VA_ARGS__)
/**
 * Shorthand for REGISTER_OP_KERNEL_EX with library type CPU and place class
 * ::paddle::platform::CPUPlace; the variadic (data type, kernel class) list is
 * forwarded as-is.
 */
#define REGISTER_OP_CPU_KERNEL_EX(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
/**
* Macro to mark what Operator and Kernel
* we will use and tell the compiler to
......
......@@ -107,19 +107,75 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
}
};
// Reshapes input "X" into output "Out".
//
// Target dims are the dims already set on "Out", unless a "Shape" input is
// present, in which case its int values are validated against X's dims by
// ReshapeOp::ValidateShape and used instead. When X has a non-empty LoD, the
// first dimension must stay the same. Depending on the "inplace" attribute,
// Out either shares X's data or receives a synchronous copy of it.
void ReshapeKernel::Compute(const framework::ExecutionContext &ctx) const {
  auto *output = ctx.Output<framework::LoDTensor>("Out");
  auto *input = ctx.Input<framework::LoDTensor>("X");

  const framework::LoDTensor *shape_input = nullptr;
  if (ctx.HasInput("Shape")) {
    shape_input = ctx.Input<framework::LoDTensor>("Shape");
  }

  framework::DDim target_dims = output->dims();
  if (shape_input != nullptr) {
    auto *dims_begin = shape_input->data<int>();
    // Lives for the whole branch so the pointer taken from it stays valid
    // while the vector below is built.
    framework::Tensor host_shape;
    if (platform::is_gpu_place(ctx.GetPlace())) {
      // The shape values are read on the host, so copy them to CPU first.
      TensorCopySync(*shape_input, platform::CPUPlace(), &host_shape);
      dims_begin = host_shape.data<int>();
    }
    std::vector<int> requested_shape(dims_begin,
                                     dims_begin + shape_input->numel());
    target_dims = ReshapeOp::ValidateShape(requested_shape, input->dims());
  }

  const bool carries_lod = !input->lod().empty();
  if (carries_lod) {
    PADDLE_ENFORCE_EQ(target_dims[0], input->dims()[0],
                      "Reshape operator cannot reshape an input sequence batch "
                      "into an output sequence batch that has a different "
                      "number of time steps. Please consider using "
                      "sequence_reshape op.");
  }

  const bool share_buffer = ctx.Attr<bool>("inplace");
  output->Resize(target_dims);
  if (share_buffer) {
    output->ShareDataWith(*input);
  } else {
    output->mutable_data(ctx.GetPlace(), input->type());
    framework::TensorCopySync(*input, ctx.GetPlace(), output);
  }
  // Dims are applied again after either path, as in the original flow
  // (presumably because the copy/share resets them -- TODO confirm).
  output->Resize(target_dims);
}
// Produces the gradient of "X" from the gradient of "Out".
//
// The dims already set on d_x are kept. Only the data comes from d_out: shared
// when the "inplace" attribute is true, otherwise copied with TensorCopy
// followed by a wait on the device context.
void ReshapeGradKernelBase::Compute(
    const framework::ExecutionContext &ctx) const {
  auto *grad_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
  auto *grad_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
  grad_x->mutable_data(ctx.GetPlace(), grad_out->type());

  const bool share_buffer = ctx.Attr<bool>("inplace");
  // Taken by value on purpose: the dims must be restored after the copy or
  // share below.
  const auto saved_dims = grad_x->dims();

  if (share_buffer) {
    grad_x->ShareDataWith(*grad_out);
  } else {
    framework::TensorCopy(*grad_out, ctx.GetPlace(), ctx.device_context(),
                          grad_x);
    ctx.device_context().Wait();
  }
  grad_x->Resize(saved_dims);
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
// Operator registrations: `reshape` is registered with ops::ReshapeOp,
// ops::ReshapeOpMaker and DefaultGradOpDescMaker<true>; `reshape_grad` is
// registered with ops::ReshapeGradOp.
REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp);
REGISTER_OP_CPU_KERNEL(reshape, ops::ReshapeKernel<CPU, float>,
ops::ReshapeKernel<CPU, double>,
ops::ReshapeKernel<CPU, int>,
ops::ReshapeKernel<CPU, int64_t>);
REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel<CPU, float>,
ops::ReshapeGradKernel<CPU, double>,
ops::ReshapeGradKernel<CPU, int>,
ops::ReshapeGradKernel<CPU, int64_t>);
// CPU kernels for `reshape`: the argument list is a flat sequence of
// (element type, kernel class) pairs, and the same non-template
// ops::ReshapeKernel is paired with float, double, int and int64_t.
REGISTER_OP_CPU_KERNEL_EX(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t,
ops::ReshapeKernel);
// CPU kernels for `reshape_grad`: one ops::ReshapeGradKernel<T> instantiation
// per element type.
REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel<float>,
ops::ReshapeGradKernel<double>,
ops::ReshapeGradKernel<int>,
ops::ReshapeGradKernel<int64_t>);
......@@ -13,14 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reshape_op.h"
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(reshape, paddle::operators::ReshapeKernel<CUDA, float>,
paddle::operators::ReshapeKernel<CUDA, double>,
paddle::operators::ReshapeKernel<CUDA, int>,
paddle::operators::ReshapeKernel<CUDA, int64_t>);
namespace ops = paddle::operators;
// CUDA registration through the EX macro: the argument list is a flat sequence
// of (element type, kernel class) pairs, pairing the non-template
// ops::ReshapeKernel with float, double, int and int64_t.
REGISTER_OP_CUDA_KERNEL_EX(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t,
ops::ReshapeKernel);
REGISTER_OP_CUDA_KERNEL(reshape_grad,
paddle::operators::ReshapeGradKernel<CUDA, float>,
paddle::operators::ReshapeGradKernel<CUDA, double>,
paddle::operators::ReshapeGradKernel<CUDA, int>,
paddle::operators::ReshapeGradKernel<CUDA, int64_t>);
paddle::operators::ReshapeGradKernel<float>,
paddle::operators::ReshapeGradKernel<double>,
paddle::operators::ReshapeGradKernel<int>,
paddle::operators::ReshapeGradKernel<int64_t>);
......@@ -118,72 +118,21 @@ class ReshapeOp : public framework::OperatorWithKernel {
}
};
template <typename DeviceContext, typename T>
class ReshapeKernel : public framework::OpKernel<T> {
class ReshapeKernel : public framework::OpKernelBase {
public:
void Compute(const framework::ExecutionContext &ctx) const {
auto *out = ctx.Output<framework::LoDTensor>("Out");
auto *in = ctx.Input<framework::LoDTensor>("X");
auto *shape_tensor = ctx.HasInput("Shape")
? ctx.Input<framework::LoDTensor>("Shape")
: nullptr;
framework::DDim out_dims = out->dims();
if (shape_tensor) {
auto *shape_data = shape_tensor->data<int>();
framework::Tensor cpu_shape_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
shape_data = cpu_shape_tensor.data<int>();
}
auto shape =
std::vector<int>(shape_data, shape_data + shape_tensor->numel());
out_dims = ReshapeOp::ValidateShape(shape, in->dims());
}
if (!in->lod().empty()) {
PADDLE_ENFORCE_EQ(
out_dims[0], in->dims()[0],
"Reshape operator cannot reshape an input sequence batch "
"into an output sequence batch that has a different "
"number of time steps. Please consider using "
"sequence_reshape op.");
}
void Compute(const framework::ExecutionContext &ctx) const final;
};
bool inplace = ctx.Attr<bool>("inplace");
out->Resize(out_dims);
if (!inplace) {
out->mutable_data<T>(ctx.GetPlace());
framework::TensorCopySync(*in, ctx.GetPlace(), out);
out->Resize(out_dims);
} else {
out->ShareDataWith(*in);
out->Resize(out_dims);
}
}
// Element-type-independent base of the reshape gradient kernels. It only
// declares Compute(); the definition is provided out of line.
// NOTE(review): ReshapeKernel::Compute is marked `final`; consider marking this
// declaration `override` for consistency -- confirm against OpKernelBase.
class ReshapeGradKernelBase : public framework::OpKernelBase {
public:
void Compute(const framework::ExecutionContext &ctx) const;
};
template <typename DeviceContext, typename T>
class ReshapeGradKernel : public framework::OpKernel<T> {
template <typename T>
class ReshapeGradKernel : public ReshapeGradKernelBase {
public:
void Compute(const framework::ExecutionContext &ctx) const {
auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(ctx.GetPlace());
bool inplace = ctx.Attr<bool>("inplace");
auto in_dims = d_x->dims();
if (!inplace) {
framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
ctx.device_context().Wait();
d_x->Resize(in_dims);
} else {
d_x->ShareDataWith(*d_out);
d_x->Resize(in_dims);
}
}
// Tell register element type.
using ELEMENT_TYPE = T;
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册