Unverified commit 04025237, authored by HongyuJia, committed by GitHub

[CustomOP Inplace] Automap inplace dtype and shape, support vector<Tensor> output (#52114)

* [CustomOP Inplace] Automap inplace dtype and shape, prepare for vector<Tensor> output

* delete dtype,shape func of multi_inplace op

* [CustomOP Inplace] Automap inplace dtype and shape, support vector<Tensor> output
Parent 888a30c9
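For context, a minimal sketch of the registration pattern this change enables (mirroring the custom_add_vec test op added in this diff; the kernel body is omitted and the declaration is illustrative only):

#include "paddle/extension.h"

// The kernel mutates its paddle::Vec("X") inputs in place. Because the pair is
// declared in SetInplaceMap, the dtype and shape of paddle::Vec("Out") are
// mapped automatically from X, so no SetInferShapeFn / SetInferDtypeFn is
// needed.
void AddVectorForward(std::vector<paddle::Tensor>& x,  // NOLINT
                      const paddle::Tensor& y);

PD_BUILD_OP(custom_add_vec)
    .Inputs({paddle::Vec("X"), "Y"})
    .Outputs({paddle::Vec("Out")})
    .SetInplaceMap({{paddle::Vec("X"), paddle::Vec("Out")}})
    .SetKernelFn(PD_KERNEL(AddVectorForward));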
@@ -268,15 +268,15 @@ static void RunKernelFunc(
  for (size_t i = 0; i < outputs.size(); ++i) {
    auto out_name = outputs[i];
    if (detail::IsDuplicableVar(out_name)) {
      PADDLE_ENFORCE(
          !inplace_map.empty() || (i == 0UL && outputs.size() == 1UL),
          phi::errors::PreconditionNotMet(
              "If custom operator's outputs contains `paddle::Vec()` type "
              "without setting InplaceMap, it only can hold one output."));
      auto vec_out = ctx.MultiOutput<phi::DenseTensor>(out_name);
      PADDLE_ENFORCE_NE(vec_out.empty(),
                        true,
                        phi::errors::NotFound(
                            "Output vector<tensor> (%s) is empty.", out_name));
      std::vector<paddle::Tensor> custom_vec_out;
      for (size_t j = 0; j < vec_out.size(); ++j) {
@@ -359,11 +359,67 @@ static void RunKernelFunc(
  }
}

static void RunDefaultInferShapeFunc(
framework::InferShapeContext* ctx,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::unordered_map<std::string, std::string>& inplace_map) {
  if (inplace_map.empty()) {  // general case: ensure a single input and output
PADDLE_ENFORCE_EQ(
inputs.size(),
1UL,
phi::errors::Unavailable(
"Your custom operator contains multiple inputs. "
"We only allow a custom operator that contains only one input "
"and only one output without setting the InferShapeFn. "
"At this time, the input shape will be directly set to "
"the output shape.\n"
"Please set the InferShapeFn of custom "
"operator by .SetInferShapeFn(PD_INFER_SHAPE(...))"));
PADDLE_ENFORCE_EQ(
outputs.size(),
1UL,
phi::errors::Unavailable(
"Your custom operator contains multiple outputs. "
"We only allow a custom operator that contains only one input "
"and only one output without setting the InferShapeFn. "
"At this time, the input shape will be directly set to "
"the output shape.\n"
"Please set the InferShapeFn of custom "
"operator by .SetInferShapeFn(PD_INFER_SHAPE(...))"));
VLOG(3) << "Custom Operator: Default InferShape - share ddim.";
ctx->ShareDim(inputs[0], outputs[0]);
} else { // inplace case
PADDLE_ENFORCE_EQ(
inplace_map.size(),
outputs.size(),
phi::errors::Unavailable(
"Your custom operator uses `SetInplaceMap` without setting the "
"InferShapeFn. However, `Outputs` size = %d does not match the "
"`InplaceMap` size = %d. Please check `SetInplaceMap` again or set "
"the InferShapeFn of custom operator by "
"`.SetInferShapeFn(PD_INFER_SHAPE(...)`)",
outputs.size(),
inplace_map.size()));
for (auto const& pair : inplace_map) {
if (detail::IsDuplicableVar(pair.first)) {
ctx->SetOutputsDim(pair.second, ctx->GetInputsDim(pair.first));
} else {
ctx->ShareDim(pair.first, pair.second);
}
}
}
}
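In the inplace branch above, each InplaceMap pair simply forwards the input dims to its mapped output. A hedged illustration, using the names of the custom_multi_inplace test op from this diff:

// With .SetInplaceMap({{"X", "OutXY"}, {"A", "OutAB"}}) and no InferShapeFn,
// RunDefaultInferShapeFunc effectively runs:
//   ctx->ShareDim("X", "OutXY");
//   ctx->ShareDim("A", "OutAB");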
static void RunInferShapeFunc(
framework::InferShapeContext* ctx,
    const paddle::InferShapeFunc& func,
    const std::vector<std::string>& inputs,
    const std::vector<std::string>& outputs,
    const std::vector<std::string>& attrs,
const std::unordered_map<std::string, std::string>& inplace_map,
const std::unordered_map<std::string, std::string>& inplace_reverse_map) {
  std::vector<std::vector<int64_t>> input_shapes;
  std::vector<std::vector<std::vector<int64_t>>> vec_input_shapes;
@@ -450,22 +506,220 @@ static void RunInferShapeFunc(framework::InferShapeContext* ctx,
VLOG(3) << "Custom Operator: InferShape - calc output ddim."; VLOG(3) << "Custom Operator: InferShape - calc output ddim.";
auto output_shapes = func(input_shapes, vec_input_shapes, custom_attrs); auto output_shapes = func(input_shapes, vec_input_shapes, custom_attrs);
if (inplace_map.empty()) {
PADDLE_ENFORCE_EQ(outputs.size(),
output_shapes.size(),
phi::errors::InvalidArgument(
"Your custom operator has set the InferShapeFn. "
"However, `Outputs` size = %d does not match the "
"returned vector size of InferShapeFn = %d. Please "
"check InferShapeFn again.",
outputs.size(),
output_shapes.size()));
} else {
PADDLE_ENFORCE_EQ(
outputs.size(),
output_shapes.size() + inplace_map.size(),
phi::errors::InvalidArgument(
"Your custom operator uses `SetInplaceMap` and sets the "
"InferShapeFn. However, `Outputs` size = %d does not match the "
"`InplaceMap size + InferShapeFn output size` = %d. Please check "
"InplaceMap and InferShapeFn again",
outputs.size(),
output_shapes.size() + inplace_map.size()));
}
VLOG(3)
<< "Custom Operator: InferShape - set output ddim: inplace_map.size() = "
<< inplace_map.size()
<< ", output_shapes.size() = " << output_shapes.size();
size_t output_shape_idx = 0;
for (size_t i = 0; i < outputs.size(); ++i) {
auto out_name = outputs[i];
if (detail::IsDuplicableVar(out_name)) {
PADDLE_ENFORCE(
inplace_reverse_map.find(out_name) != inplace_reverse_map.end(),
phi::errors::InvalidArgument(
"Custom operator only supports `paddle::Vec(...)` inputs and "
"cannot support `paddle::Vec(...)` output without setting "
"InplaceMap. If you have to use `paddle::Vec(...)` output, "
"please indicate it by setting InplaceMap manully."));
auto in_name = inplace_reverse_map.at(out_name);
ctx->SetOutputsDim(out_name, ctx->GetInputsDim(in_name));
} else {
if (inplace_reverse_map.find(out_name) != inplace_reverse_map.end()) {
// Share dims between inplace inputs and outputs
ctx->ShareDim(inplace_reverse_map.at(out_name), out_name);
} else {
// Set output dims by the output of InferShapeFn
ctx->SetOutputDim(out_name,
phi::make_ddim(output_shapes[output_shape_idx++]));
}
}
}
}
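Note the output_shape_idx bookkeeping above: when SetInplaceMap is combined with an InferShapeFn, the InferShapeFn returns shapes only for the non-inplace outputs, while inplace outputs take their dims from inplace_reverse_map. A hedged sketch (op layout and function name hypothetical):

// Hypothetical op: Outputs({"OutXY", "Extra"}) with InplaceMap {{"X", "OutXY"}}.
// The InferShapeFn must return outputs.size() - inplace_map.size() == 1 shape,
// which is consumed by "Extra"; "OutXY" shares its dims with "X".
std::vector<std::vector<int64_t>> MyInferShape(
    const std::vector<int64_t>& x_shape, const std::vector<int64_t>& y_shape) {
  return {y_shape};  // shape for "Extra" only
}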
static void RunDefaultInferDtypeFunc(
framework::InferVarTypeContext* ctx,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::unordered_map<std::string, std::string>& inplace_map) {
  if (inplace_map.empty()) {  // general case: ensure a single input and output
PADDLE_ENFORCE_EQ(
inputs.size(),
1UL,
platform::errors::Unavailable(
"Your custom operator contains multiple inputs. "
"We only allow a custom operator that contains only one input "
"and only one output without setting the InferDtypeFn. "
"At this time, the input dtype will be directly set to "
"the output dtype.\n"
"Please set the InferDtypeFn of custom "
"operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`"));
PADDLE_ENFORCE_EQ(
outputs.size(),
1UL,
platform::errors::Unavailable(
"Your custom operator contains multiple outputs. "
"We only allow a custom operator that contains only one input "
"and only one output without setting the InferDtypeFn. "
"At this time, the input dtype will be directly set to "
"the output dtype.\n"
"Please set the InferDtypeFn of custom "
"operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`"));
VLOG(3) << "Custom Operator: InferDtype - share dtype.";
auto dtype = ctx->GetInputDataType(inputs[0]);
ctx->SetOutputDataType(outputs[0], dtype);
} else { // inplace case
PADDLE_ENFORCE_EQ(
inplace_map.size(),
outputs.size(),
phi::errors::Unavailable(
"Your custom operator uses `SetInplaceMap` without setting the "
"InferDtypeFn. However, `Outputs` size = %d does not match the "
"`InplaceMap` size = %d. Please check `SetInplaceMap` again or set "
"the InferDtypeFn of custom operator by "
"`.SetInferDtypeFn(PD_INFER_DTYPE(...))`",
outputs.size(),
inplace_map.size()));
for (auto const& pair : inplace_map) {
VLOG(3) << "Custom Operator: InferDtype - inplace dtype: " << pair.first
<< "->" << pair.second;
if (detail::IsDuplicableVar(pair.first)) {
size_t size = ctx->InputSize(pair.first);
for (size_t i = 0; i < size; ++i) {
auto dtype = ctx->GetInputDataType(pair.first, i);
ctx->SetOutputDataType(pair.second, dtype, i);
}
} else {
auto dtype = ctx->GetInputDataType(pair.first);
ctx->SetOutputDataType(pair.second, dtype);
}
}
}
}
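For the duplicable (paddle::Vec) branch above, dtypes are forwarded element-wise. A hedged illustration, assuming paddle::Vec("X") expands to the "X@VECTOR" key used by the Python test in this diff:

// With .SetInplaceMap({{paddle::Vec("X"), paddle::Vec("Out")}}) and no
// InferDtypeFn, the default function effectively runs:
//   for (size_t i = 0; i < ctx->InputSize("X@VECTOR"); ++i) {
//     ctx->SetOutputDataType("Out@VECTOR",
//                            ctx->GetInputDataType("X@VECTOR", i),
//                            i);
//   }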
VLOG(3) << "Custom Operator: InferShape - set output ddim."; static void RunInferDtypeFunc(
framework::InferVarTypeContext* ctx,
const paddle::InferDtypeFunc& func,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::unordered_map<std::string, std::string>& inplace_map,
const std::unordered_map<std::string, std::string>& inplace_reverse_map) {
std::vector<DataType> input_dtypes;
std::vector<std::vector<DataType>> vec_input_dtypes;
VLOG(3) << "Custom Operator: InferDtype - get input dtype.";
for (auto& in_name : inputs) {
if (detail::IsDuplicableVar(in_name)) {
std::vector<DataType> vec_custom_dtype;
if (ctx->HasInput(in_name)) { // general inputs
for (size_t i = 0; i < ctx->InputSize(in_name); ++i) {
auto dtype = ctx->GetInputDataType(in_name, i);
vec_custom_dtype.emplace_back(
paddle::framework::TransToPhiDataType(dtype));
}
} else { // optional inputs, `vec_custom_dtype` is empty
PADDLE_ENFORCE(
detail::IsOptionalVar(in_name),
phi::errors::NotFound("Your custom operator's InferDtypeFn "
"cannot find input parameter `%s`",
in_name));
VLOG(3) << "Custom Operator: InferDtypeFn's vector input " << in_name
<< " is optional dtype with None input";
}
vec_input_dtypes.emplace_back(vec_custom_dtype);
} else {
if (ctx->HasInput(in_name)) { // general inputs
auto dtype = ctx->GetInputDataType(in_name);
input_dtypes.emplace_back(paddle::framework::TransToPhiDataType(dtype));
} else { // optional inputs
PADDLE_ENFORCE(
detail::IsOptionalVar(in_name),
phi::errors::NotFound("Your custom operator's InferDtypeFn "
"cannot find input parameter `%s`",
in_name));
input_dtypes.emplace_back(DataType::UNDEFINED);
VLOG(3) << "Custom Operator: InferDtypeFn's input " << in_name
<< " is optional dtype with None input";
}
}
}
VLOG(3) << "Custom Operator: InferDtype - infer output dtype.";
auto output_dtypes = func(input_dtypes, vec_input_dtypes);
if (inplace_map.empty()) {
PADDLE_ENFORCE_EQ(outputs.size(),
output_dtypes.size(),
phi::errors::InvalidArgument(
"Your custom operator has set the InferDtypeFn. "
"However, `Outputs` size = %d does not match the "
"returned vector size of InferDtypeFn = %d. Please "
"check InferDtypeFn again.",
outputs.size(),
output_dtypes.size()));
} else {
PADDLE_ENFORCE_EQ(
outputs.size(),
output_dtypes.size() + inplace_map.size(),
phi::errors::InvalidArgument(
"Your custom operator uses `SetInplaceMap` and sets the "
"InferDtypeFn. However, `Outputs` size = %d does not match the "
"`InplaceMap size + InferDtypeFn output size` = %d. Please check "
"InplaceMap and InferDtypeFn again",
outputs.size(),
output_dtypes.size() + inplace_map.size()));
}
VLOG(3)
<< "Custom Operator: InferDtype - set output dtype: inplace_map.size() = "
<< inplace_map.size()
<< ", output_dtypes.size() = " << output_dtypes.size();
size_t output_dtype_idx = 0;
  for (size_t i = 0; i < outputs.size(); ++i) {
    auto out_name = outputs[i];
    if (detail::IsDuplicableVar(out_name)) {
      PADDLE_ENFORCE(
          inplace_reverse_map.find(out_name) != inplace_reverse_map.end(),
          phi::errors::InvalidArgument(
              "Custom operator only supports `paddle::Vec(...)` inputs and "
              "cannot support `paddle::Vec(...)` output without setting "
              "InplaceMap. If you have to use `paddle::Vec(...)` output, "
              "please indicate it by setting InplaceMap manually."));
      auto in_name = inplace_reverse_map.at(out_name);
      ctx->SetOutputDataTypes(out_name, ctx->GetInputDataTypes(in_name));
    } else {
      if (inplace_reverse_map.find(out_name) != inplace_reverse_map.end()) {
        auto in_name = inplace_reverse_map.at(out_name);
        // Share dtype between inplace inputs and outputs
        ctx->SetOutputDataType(out_name, ctx->GetInputDataType(in_name));
      } else {
        // Set output dtype by the output of InferDtypeFn
        ctx->SetOutputDataType(out_name,
                               paddle::framework::TransToProtoVarType(
                                   output_dtypes[output_dtype_idx++]));
      }
    }
  }
}
@@ -822,6 +1076,8 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
  auto& op_outputs = OpMetaInfoHelper::GetOutputs(base_op_meta);
  auto& op_attrs = OpMetaInfoHelper::GetAttrs(base_op_meta);
  auto& op_inplace_map = OpMetaInfoHelper::GetInplaceMap(base_op_meta);
auto& op_inplace_reverse_map =
OpMetaInfoHelper::GetInplaceReverseMap(base_op_meta);
  auto& kernel_fn = OpMetaInfoHelper::GetKernelFn(base_op_meta);
  auto& infer_shape_func = OpMetaInfoHelper::GetInferShapeFn(base_op_meta);
  auto& infer_dtype_func = OpMetaInfoHelper::GetInferDtypeFn(base_op_meta);
@@ -873,133 +1129,46 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
  // InferShape
  if (infer_shape_func == nullptr) {
    // use default InferShape
    info.infer_shape_ =
        [op_inputs, op_outputs, op_inplace_map](InferShapeContext* ctx) {
          RunDefaultInferShapeFunc(ctx, op_inputs, op_outputs, op_inplace_map);
        };
  } else {
    info.infer_shape_ = [op_inputs,
                         op_outputs,
                         op_attrs,
                         op_inplace_map,
                         op_inplace_reverse_map,
                         infer_shape_func](InferShapeContext* ctx) {
      RunInferShapeFunc(ctx,
                        infer_shape_func,
                        op_inputs,
                        op_outputs,
                        op_attrs,
                        op_inplace_map,
                        op_inplace_reverse_map);
    };
  }
  // Infer Dtype
  if (infer_dtype_func == nullptr) {
    // use default InferDtype
    info.infer_var_type_ =
        [op_inputs, op_outputs, op_inplace_map](InferVarTypeContext* ctx) {
          RunDefaultInferDtypeFunc(ctx, op_inputs, op_outputs, op_inplace_map);
        };
  } else {
    info.infer_var_type_ = [op_inputs,
                            op_outputs,
                            op_inplace_map,
                            op_inplace_reverse_map,
                            infer_dtype_func](InferVarTypeContext* ctx) {
      RunInferDtypeFunc(ctx,
                        infer_dtype_func,
                        op_inputs,
                        op_outputs,
                        op_inplace_map,
                        op_inplace_reverse_map);
    };
  }
@@ -1022,6 +1191,8 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
    auto& grad_op_outputs = OpMetaInfoHelper::GetOutputs(cur_grad_op);
    auto& grad_op_attrs = OpMetaInfoHelper::GetAttrs(cur_grad_op);
    auto& grad_op_inplace_map = OpMetaInfoHelper::GetInplaceMap(cur_grad_op);
auto& grad_op_inplace_reverse_map =
OpMetaInfoHelper::GetInplaceReverseMap(cur_grad_op);
    auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op);
    auto& grad_infer_shape_fn = OpMetaInfoHelper::GetInferShapeFn(cur_grad_op);
@@ -1092,6 +1263,13 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
      return new CustomOperator(type, inputs, outputs, attrs);
    };
// Inplace
if (!grad_op_inplace_map.empty()) {
grad_info.infer_inplace_ = [grad_op_inplace_map](bool use_cuda) {
return grad_op_inplace_map;
};
}
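A hedged note on what the new block above reports, using the custom_add grad op from this diff and assuming paddle::Grad("Out") expands to the "Out@GRAD" key seen in the static-graph test:

    // With .SetInplaceMap({{paddle::Grad("Out"), paddle::Grad("X")}}),
    // grad_info.infer_inplace_ returns {"Out@GRAD" -> "X@GRAD"}, letting the
    // framework reuse the out_grad buffer for x_grad.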
    // Grad InferShape
    if (grad_infer_shape_fn == nullptr) {
      grad_info.infer_shape_ = [grad_op_inputs,
@@ -1135,12 +1313,16 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
      grad_info.infer_shape_ = [grad_op_inputs,
                                grad_op_outputs,
                                grad_op_attrs,
grad_op_inplace_map,
grad_op_inplace_reverse_map,
                                grad_infer_shape_fn](InferShapeContext* ctx) {
        RunInferShapeFunc(ctx,
                          grad_infer_shape_fn,
                          grad_op_inputs,
                          grad_op_outputs,
                          grad_op_attrs,
grad_op_inplace_map,
grad_op_inplace_reverse_map);
      };
    }
...
@@ -518,6 +518,7 @@ static PyObject* eager_api_run_custom_op(PyObject* self,
"sure you registered your op first and try again. ", "sure you registered your op first and try again. ",
op_type)); op_type));
VLOG(7) << "Run Kernel of Custom Op: " << op_type; VLOG(7) << "Run Kernel of Custom Op: " << op_type;
// TODO(HongyuJia): Optimize Attrs Cast naming and implementation
  std::vector<paddle::any> res_attrs = CastAttrsToTargetType(
      ctx.Attrs(),
      paddle::OpMetaInfoHelper::GetAttrs(meta_info_map.at(op_type)[0]));
...
@@ -196,6 +196,7 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
  template <typename... RemainingArgs>
  struct ComputeCallHelper;
// Handle args for general Tensor input case
  template <typename... Tail>
  struct ComputeCallHelper<const Tensor&, Tail...> {
    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
@@ -209,6 +210,7 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
    }
  };
// Handle args for optional Tensor input case
  template <typename... Tail>
  struct ComputeCallHelper<const paddle::optional<paddle::Tensor>&, Tail...> {
    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
@@ -228,6 +230,7 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
    }
  };
// Handle args for general vector<Tensor> input case
  template <typename... Tail>
  struct ComputeCallHelper<const std::vector<Tensor>&, Tail...> {
    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
@@ -241,6 +244,7 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
    }
  };
// Handle args for optional vector<Tensor> input case
  template <typename... Tail>
  struct ComputeCallHelper<const paddle::optional<std::vector<paddle::Tensor>>&,
                           Tail...> {
@@ -293,6 +297,7 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
  // Used to be compatible with 2.3 released internal inplace interface, not
  // recommended
// Handle args for compatible inplace case
  template <typename... Tail>
  struct ComputeCallHelper<Tensor*, Tail...> {
    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
@@ -310,6 +315,7 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
  // recommended
  // TODO(chenweihang): What is the appropriate output form?
  // std::vector<Tensor>*? or std::vector<Tensor*>? or std::vector<Tensor*>*
// Handle args for compatible inplace case
  template <typename... Tail>
  struct ComputeCallHelper<std::vector<Tensor*>, Tail...> {
    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
@@ -323,7 +329,7 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
    }
  };

  // Handle args for inplace Tensor case
  template <typename... Tail>
  struct ComputeCallHelper<Tensor&, Tail...> {
    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
@@ -337,6 +343,20 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
    }
  };
// Handle args for inplace vector<Tensor> case
template <typename... Tail>
struct ComputeCallHelper<std::vector<Tensor>&, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) {
auto& range = ctx->InputRangeAt(in_idx);
auto arg = ctx->InputsBetween(range.first, range.second);
ComputeCallHelper<
Tail...>::template Compute<in_idx + 1, attr_idx, out_idx>(ctx,
pargs...,
arg);
}
};
  template <int out_idx, typename T>
  struct ComputeReturnHelper;
@@ -739,6 +759,7 @@ class PADDLE_API OpMetaInfo {
  std::vector<std::string> outputs_;
  std::vector<std::string> attrs_;
  std::unordered_map<std::string, std::string> inplace_map_;
std::unordered_map<std::string, std::string> inplace_reverse_map_;
  // 2. func info
  KernelFunc kernel_fn_{nullptr};
  InferShapeFunc infer_shape_fn_{nullptr};
@@ -767,6 +788,10 @@ class OpMetaInfoHelper {
      const paddle::OpMetaInfo& info) {
    return info.inplace_map_;
  }
static const std::unordered_map<std::string, std::string>&
GetInplaceReverseMap(const paddle::OpMetaInfo& info) {
return info.inplace_reverse_map_;
}
  static const KernelFunc& GetKernelFn(const paddle::OpMetaInfo& info) {
    return info.kernel_fn_;
  }
...
@@ -134,6 +134,7 @@ const std::pair<size_t, size_t>& CustomOpKernelContext::OutputRangeAt(
// handle inplace mechanism
// Find out non-inplace output tensors.
// TODO(HongyuJia): Add cache for inplace_tensor_map_ to optimize performance
void CustomOpKernelContext::MapPlainOutputs(
    const std::vector<std::string>& inputs,
    const std::vector<std::string>& outputs,
@@ -215,6 +216,9 @@ OpMetaInfo& OpMetaInfo::SetInplaceMap(
    std::unordered_map<std::string, std::string>&& inplace_map) {
  inplace_map_ =
      std::forward<std::unordered_map<std::string, std::string>>(inplace_map);
for (const auto& pair : inplace_map_) {
inplace_reverse_map_[pair.second] = pair.first;
}
  return *this;
}
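A hedged illustration of the two maps SetInplaceMap now records, using the custom_add test op from this diff:

// .SetInplaceMap({{"X", "Out"}}) stores both directions:
//   inplace_map_         == {{"X", "Out"}}  // input  -> output
//   inplace_reverse_map_ == {{"Out", "X"}}  // output -> input
// RunInferShapeFunc / RunInferDtypeFunc look each output up in the reverse map
// to decide whether to share dims/dtypes from the paired input instead of
// consuming an InferShapeFn / InferDtypeFn result.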

OpMetaInfo& OpMetaInfo::SetKernelFn(KernelFunc&& func) {
...
@@ -19,18 +19,18 @@
#include "paddle/extension.h" #include "paddle/extension.h"
template <typename data_t>
void add_data_pointer(const data_t* x_data, data_t* out_data, int64_t numel) {
  for (size_t i = 0; i < numel; ++i) {
    out_data[i] += x_data[i];
  }
}

template <typename data_t>
void assign_data_pointer(const data_t* x_data,
                         data_t* out_data,
                         int64_t numel) {
  for (size_t i = 0; i < numel; ++i) {
    out_data[i] = x_data[i];
  }
}
@@ -54,23 +54,12 @@ void relu_backward_kernel(const data_t* out_data,
void AddForward(paddle::Tensor& x, const paddle::Tensor& y) {  // NOLINT
  PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");

  PD_DISPATCH_FLOATING_TYPES(
      x.type(), "AddForward", ([&] {
        add_data_pointer<data_t>(y.data<data_t>(), x.data<data_t>(), x.size());
      }));
}
std::vector<paddle::Tensor> AddBackward(const paddle::Tensor& x,
                                        const paddle::Tensor& y,
                                        paddle::Tensor& out_grad) {  // NOLINT
@@ -81,8 +70,8 @@ std::vector<paddle::Tensor> AddBackward(const paddle::Tensor& x,
  PD_DISPATCH_FLOATING_TYPES(
      out_grad.type(), "AddBackward", ([&] {
        assign_data_pointer<data_t>(
            out_grad.data<data_t>(), y_grad.data<data_t>(), out_grad.size());
      }));

  return {y_grad};
@@ -92,9 +81,7 @@ PD_BUILD_OP(custom_add)
.Inputs({"X", "Y"}) .Inputs({"X", "Y"})
.Outputs({"Out"}) .Outputs({"Out"})
.SetInplaceMap({{"X", "Out"}}) .SetInplaceMap({{"X", "Out"}})
    .SetKernelFn(PD_KERNEL(AddForward));
PD_BUILD_GRAD_OP(custom_add)
    .Inputs({"X", "Y", paddle::Grad("Out")})
@@ -102,6 +89,58 @@ PD_BUILD_GRAD_OP(custom_add)
.SetInplaceMap({{paddle::Grad("Out"), paddle::Grad("X")}}) .SetInplaceMap({{paddle::Grad("Out"), paddle::Grad("X")}})
.SetKernelFn(PD_KERNEL(AddBackward)); .SetKernelFn(PD_KERNEL(AddBackward));
// out[i] = x[i] + y
void AddVectorForward(std::vector<paddle::Tensor>& x, // NOLINT
const paddle::Tensor& y) {
PD_CHECK(y.place() == paddle::PlaceType::kCPU, "y must be a CPU Tensor.");
PD_DISPATCH_FLOATING_TYPES(y.type(), "AddVectorForward", ([&] {
for (size_t i = 0; i < x.size(); ++i) {
add_data_pointer<data_t>(y.data<data_t>(),
x[i].data<data_t>(),
y.size());
}
}));
}
// x_grad[i] = out_grad[i] (needs no kernel code; handled automatically by the
// inplace map)
// y_grad = out_grad[0] + ... + out_grad[n - 1]
std::vector<paddle::Tensor> AddVectorBackward(
const std::vector<paddle::Tensor>& x,
const paddle::Tensor& y,
std::vector<paddle::Tensor>& out_grad) { // NOLINT
PD_CHECK(x[0].place() == paddle::PlaceType::kCPU,
"x[0] must be a CPU Tensor.");
PD_CHECK(y.place() == paddle::PlaceType::kCPU, "y must be a CPU Tensor.");
PD_CHECK(x.size() == out_grad.size(),
"x must have the same size as out_grad.");
paddle::Tensor y_grad = paddle::zeros(y.shape(), y.dtype(), y.place());
PD_DISPATCH_FLOATING_TYPES(
y.type(), "AddVectorBackward", ([&] {
// y_grad = out_grad[0] + ... + out_grad[n - 1]
for (size_t i = 0; i < out_grad.size(); ++i) {
add_data_pointer<data_t>(
out_grad[i].data<data_t>(), y_grad.data<data_t>(), y_grad.size());
}
}));
return {y_grad};
}
PD_BUILD_OP(custom_add_vec)
.Inputs({paddle::Vec("X"), "Y"})
.Outputs({paddle::Vec("Out")})
.SetInplaceMap({{paddle::Vec("X"), paddle::Vec("Out")}})
.SetKernelFn(PD_KERNEL(AddVectorForward));
PD_BUILD_GRAD_OP(custom_add_vec)
.Inputs({paddle::Vec("X"), "Y", paddle::Grad(paddle::Vec("Out"))})
.Outputs({paddle::Grad(paddle::Vec("X")), paddle::Grad("Y")})
.SetInplaceMap({{paddle::Grad(paddle::Vec("Out")),
paddle::Grad(paddle::Vec("X"))}})
.SetKernelFn(PD_KERNEL(AddVectorBackward));
void MultiInplaceForward(paddle::Tensor& x,  // NOLINT
                         const paddle::Tensor& y,
                         paddle::Tensor& a,  // NOLINT
@@ -111,29 +150,11 @@ void MultiInplaceForward(paddle::Tensor& x,  // NOLINT
  PD_DISPATCH_FLOATING_TYPES(
      x.type(), "MultiInplaceForward", ([&] {
        add_data_pointer<data_t>(y.data<data_t>(), x.data<data_t>(), x.size());
        add_data_pointer<data_t>(b.data<data_t>(), a.data<data_t>(), a.size());
      }));
}
std::vector<paddle::Tensor> MultiInplaceBackward(
    const paddle::Tensor& x,
    const paddle::Tensor& y,
@@ -151,11 +172,11 @@ std::vector<paddle::Tensor> MultiInplaceBackward(
  PD_DISPATCH_FLOATING_TYPES(
      outxy_grad.type(), "MultiInplaceBackward", ([&] {
        assign_data_pointer<data_t>(outxy_grad.data<data_t>(),
                                    y_grad.data<data_t>(),
                                    outxy_grad.size());
        assign_data_pointer<data_t>(outab_grad.data<data_t>(),
                                    b_grad.data<data_t>(),
                                    outab_grad.size());
      }));
@@ -166,9 +187,7 @@ PD_BUILD_OP(custom_multi_inplace)
.Inputs({"X", "Y", "A", "B"}) .Inputs({"X", "Y", "A", "B"})
.Outputs({"OutXY", "OutAB"}) .Outputs({"OutXY", "OutAB"})
.SetInplaceMap({{"X", "OutXY"}, {"A", "OutAB"}}) .SetInplaceMap({{"X", "OutXY"}, {"A", "OutAB"}})
    .SetKernelFn(PD_KERNEL(MultiInplaceForward));
PD_BUILD_GRAD_OP(custom_multi_inplace)
    .Inputs({"X", "Y", paddle::Grad("OutXY"), "A", "B", paddle::Grad("OutAB")})
...
@@ -40,6 +40,54 @@ custom_inplace = load(
    verbose=True,
)
# Temporarily assemble custom python API
import paddle.fluid.core as core
from paddle.fluid.core import CustomOpKernelContext
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
def custom_add_vec(x_vector, y):
# prepare inputs and outputs
attrs = {}
outs = {}
out_names = ["Out@VECTOR"]
    # The output variable's dtype uses the default value 'float32';
    # the actual dtype of the output variable will be inferred at runtime.
if in_dygraph_mode():
ctx = CustomOpKernelContext()
for i in [x_vector, y]:
ctx.add_inputs(i)
for out_name in out_names:
outs[out_name] = [core.eager.Tensor() for _ in range(len(x_vector))]
ctx.add_outputs(outs[out_name])
core.eager._run_custom_op(ctx, "custom_add_vec", True)
else:
ins = {}
for key, value in dict({"X@VECTOR": x_vector, "Y": y}).items():
# handle optional inputs
if value is not None:
ins[key] = value
helper = LayerHelper("custom_add_vec", **locals())
for out_name in out_names:
outs[out_name] = [
helper.create_variable(dtype='float32')
for _ in range(len(x_vector))
]
helper.append_op(
type="custom_add_vec", inputs=ins, outputs=outs, attrs=attrs
)
res = [outs[out_name] for out_name in out_names]
return res[0] if len(res) == 1 else res
# Set custom python API manually
custom_inplace.custom_add_vec = custom_add_vec
def inplace_dynamic_add(phi_func, device, dtype, np_x, np_y):
    paddle.set_device(device)
@@ -88,7 +136,89 @@ def inplace_static_add(func, device, dtype, np_x, np_y):
    return x_v, out_v, x_grad_v, y_grad_v, out_grad_v


def inplace_dynamic_add_vector(phi_func, device, dtype, np_inputs, np_y):
paddle.set_device(device)
inputs = [
paddle.to_tensor(np_input, dtype=dtype, stop_gradient=True)
for np_input in np_inputs
]
y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)
if phi_func:
out = custom_inplace.custom_add_vec(inputs, y)
else:
out = [x.add_(y) for x in inputs]
mean_out = paddle.mean(paddle.concat(out))
mean_out.backward()
return (
np.concatenate([input.numpy() for input in inputs]),
y.numpy(),
np.concatenate([o.numpy() for o in out]),
np.concatenate([input.grad.numpy() for input in inputs]),
y.grad.numpy(),
)
def inplace_static_add_vector(phi_func, device, dtype, np_inputs, np_y):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
x1 = static.data(
name="x1", shape=[None, np_inputs[0].shape[1]], dtype=dtype
)
x2 = static.data(
name="x2", shape=[None, np_inputs[1].shape[1]], dtype=dtype
)
y = static.data(name="y", shape=[None, np_y.shape[1]], dtype=dtype)
x1.stop_gradient = False
x2.stop_gradient = False
y.stop_gradient = False
if phi_func:
out = custom_inplace.custom_add_vec([x1, x2], y)
else:
out = [paddle.add(x1, y), paddle.add(x2, y)]
mean_out = paddle.mean(paddle.concat(out))
static.append_backward(mean_out)
exe = static.Executor()
exe.run(static.default_startup_program())
(
out0_v,
out1_v,
x1_grad_v,
x2_grad_v,
y_grad_v,
out0_grad_v,
out1_grad_v,
) = exe.run(
static.default_main_program(),
feed={
"x1": np_inputs[0].astype(dtype),
"x2": np_inputs[1].astype(dtype),
"y": np_y.astype(dtype),
},
fetch_list=[
out[0].name,
out[1].name,
x1.name + "@GRAD",
x2.name + "@GRAD",
y.name + "@GRAD",
out[0].name + "@GRAD",
out[1].name + "@GRAD",
],
)
paddle.disable_static()
return (
[out0_v, out1_v],
[x1_grad_v, x2_grad_v],
y_grad_v,
[out0_grad_v, out1_grad_v],
)
def inplace_dynamic_relu_net(phi_func, device, dtype, np_x, np_y, np_z):
    paddle.set_device(device)
    x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
    y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)
@@ -107,7 +237,7 @@ def inplace_dynamic_relu_net(phi_func, device, dtype, np_x, np_y, np_z):
    return x.numpy(), y.numpy(), out.numpy(), x.grad.numpy(), y.grad.numpy()


def inplace_static_relu_net(func, device, dtype, np_x, np_y, np_z):
    paddle.enable_static()
    paddle.set_device(device)
    with static.scope_guard(static.Scope()):
@@ -255,6 +385,10 @@ class TestCustomInplaceJit(unittest.TestCase):
        self.np_z = np.random.random((3, 2)).astype("float32")
        self.np_a = np.random.random((3, 2)).astype("float32")
        self.np_b = np.random.random((3, 2)).astype("float32")
self.np_inputs = [
np.random.random((3, 2)).astype("float32"),
np.random.random((3, 2)).astype("float32"),
]
    def check_output(self, out, pd_out, name):
        np.testing.assert_array_equal(
@@ -354,7 +488,79 @@ class TestCustomInplaceJit(unittest.TestCase):
self.check_output(phi_x_grad, pd_x_grad, "x_grad") self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad") self.check_output(phi_y_grad, pd_y_grad, "y_grad")
def test_static_multiple_inplace_relu(self): def test_static_add_vector(self):
for device in self.devices:
for dtype in self.dtypes:
(
pd_out,
pd_x_grad,
pd_y_grad,
pd_out_grad,
) = inplace_static_add_vector(
True,
device,
dtype,
self.np_inputs,
self.np_y,
)
(
phi_out,
phi_x_grad,
phi_y_grad,
phi_out_grad,
) = inplace_static_add_vector(
False,
device,
dtype,
self.np_inputs,
self.np_y,
)
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad")
self.check_output(phi_out_grad, pd_out_grad, "out_grad")
def test_dynamic_add_vector(self):
for device in self.devices:
for dtype in self.dtypes:
(
pd_x,
pd_y,
pd_out,
pd_x_grad,
pd_y_grad,
) = inplace_dynamic_add_vector(
True,
device,
dtype,
self.np_inputs,
self.np_y,
)
(
phi_x,
phi_y,
phi_out,
phi_x_grad,
phi_y_grad,
) = inplace_dynamic_add_vector(
False,
device,
dtype,
self.np_inputs,
self.np_y,
)
self.check_output(phi_x, phi_out, "inplace_phi_x")
self.check_output(pd_x, pd_out, "inplace_pd_x")
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_y, pd_y, "y")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad")
def test_static_relu_net(self):
        for device in self.devices:
            for dtype in self.dtypes:
                (
@@ -363,7 +569,7 @@ class TestCustomInplaceJit(unittest.TestCase):
                    pd_out,
                    pd_x_grad,
                    pd_y_grad,
                ) = inplace_static_relu_net(
                    paddle.nn.functional.relu,
                    device,
                    dtype,
@@ -377,7 +583,7 @@ class TestCustomInplaceJit(unittest.TestCase):
                    phi_out,
                    phi_x_grad,
                    phi_y_grad,
                ) = inplace_static_relu_net(
                    custom_inplace.custom_relu_inplace,
                    device,
                    dtype,
@@ -391,7 +597,7 @@ class TestCustomInplaceJit(unittest.TestCase):
                self.check_output_allclose(phi_x_grad, pd_x_grad, "x_grad")
                self.check_output_allclose(phi_y_grad, pd_y_grad, "y_grad")

    def test_dynamic_relu_net(self):
        for device in self.devices:
            for dtype in self.dtypes:
                (
@@ -400,7 +606,7 @@ class TestCustomInplaceJit(unittest.TestCase):
                    pd_out,
                    pd_x_grad,
                    pd_y_grad,
                ) = inplace_dynamic_relu_net(
                    False,
                    device,
                    dtype,
@@ -414,7 +620,7 @@ class TestCustomInplaceJit(unittest.TestCase):
                    phi_out,
                    phi_x_grad,
                    phi_y_grad,
                ) = inplace_dynamic_relu_net(
                    True,
                    device,
                    dtype,
...