提交 c1545a4b 编写于 作者: L liangjianzhong

compile 1

上级 701d3fa0
...@@ -2,3 +2,13 @@ cc_library( ...@@ -2,3 +2,13 @@ cc_library(
dist_tensor_spec dist_tensor_spec
SRCS dist_tensor_spec.cc SRCS dist_tensor_spec.cc
DEPS dist_attr) DEPS dist_attr)
# Library target spmd_rule_base: built from common.cc, depends on the
# dist_tensor_spec target. (cc_library is a project-provided CMake helper
# defined outside this file.)
cc_library(spmd_rule_base SRCS common.cc DEPS dist_tensor_spec)
# Library target matmul_spmd_rule: built from matmul_spmd_rule.cc, depends on
# the spmd_rule_base target. (cc_library is a project-provided CMake helper
# defined outside this file.)
cc_library(matmul_spmd_rule SRCS matmul_spmd_rule.cc DEPS spmd_rule_base)
...@@ -13,11 +13,12 @@ See the License for the specific language governing permissions and ...@@ -13,11 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
namespace auto_parallel { namespace auto_parallel {
using phi::distributed::auto_parallel::str_join;
std::vector<TensorDistAttr> MatmulSPMDRule::InferForward( std::vector<TensorDistAttr> MatmulSPMDRule::InferForward(
const std::vector<DistTensorSpec>& input_specs, const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) { const paddle::framework::AttributeMap& attrs) {
...@@ -29,26 +30,31 @@ std::vector<TensorDistAttr> MatmulSPMDRule::InferForward( ...@@ -29,26 +30,31 @@ std::vector<TensorDistAttr> MatmulSPMDRule::InferForward(
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"The size of InputSpec of matmul should be 2, but got [%d].", "The size of InputSpec of matmul should be 2, but got [%d].",
input_specs_size)); input_specs_size));
auto x_shape = input_specs[0].get_shape();
int x_ndim = input_specs[0].shape.size(); auto y_shape = input_specs[1].get_shape();
int y_ndim = input_specs[1].shape.size(); int x_ndim = x_shape.size();
std::vector<int64_t> x_dims_mapping = input_specs[0].DistAttr.dims_mapping; int y_ndim = y_shape.size();
std::vector<int64_t> y_dims_mapping = input_specs[0].DistAttr.dims_mapping; auto x_dist_attr_src = input_specs[0].get_dist_attr();
auto y_dist_attr_src = input_specs[1].get_dist_attr();
std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
std::vector<int64_t> y_dims_mapping = y_dist_attr_src.dims_mapping();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_ndim, x_ndim,
x_dims_mapping.size() phi::errors::InvalidArgument( x_dims_mapping.size(),
phi::errors::InvalidArgument(
"Mismatch of X's tensor size: [%d] and X's dims_mapping size [%d].", "Mismatch of X's tensor size: [%d] and X's dims_mapping size [%d].",
x_ndim, x_ndim,
x_dims_mapping.size())); x_dims_mapping.size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
y_ndim, y_ndim,
y_dims_mapping.size() phi::errors::InvalidArgument( y_dims_mapping.size(),
phi::errors::InvalidArgument(
"Mismatch of Y's tensor size: [%d] and Y's dims_mapping size [%d].", "Mismatch of Y's tensor size: [%d] and Y's dims_mapping size [%d].",
x_ndim, x_ndim,
x_dims_mapping.size())); x_dims_mapping.size()));
bool trans_x = ExtractAttr<bool>("trans_x"); bool trans_x = ExtractAttr<bool>("trans_x", attrs);
bool trans_y = ExtractAttr<bool>("trans_y"); bool trans_y = ExtractAttr<bool>("trans_y", attrs);
// step1: build Einsum Notation // step1: build Einsum Notation
...@@ -129,19 +135,19 @@ std::vector<TensorDistAttr> MatmulSPMDRule::InferForward( ...@@ -129,19 +135,19 @@ std::vector<TensorDistAttr> MatmulSPMDRule::InferForward(
// step2.2: Infer Output's Dims Mapping. // step2.2: Infer Output's Dims Mapping.
TensorDistAttr output_dist_attr_dst = TensorDistAttr output_dist_attr_dst =
CopyTensorDistAttrForOutput(input_specs[0].DistAttr); CopyTensorDistAttrForOutput(x_dist_attr_src);
std::vector<int64_t> out_dims_mapping; std::vector<int64_t> out_dims_mapping;
out_dims_mapping.reserve(out_axes.size()); out_dims_mapping.reserve(out_axes.size());
for (int i = 0; i < out_axes.size(); ++i) { for (size_t i = 0; i < out_axes.size(); ++i) {
out_dims_mapping.push_back(axis_to_dim_map[out_axes.substr(i, 1)]); out_dims_mapping.push_back(axis_to_dim_map[out_axes.substr(i, 1)]);
} }
output_dist_attr_dst.set_dims_mapping(out_dims_mapping); output_dist_attr_dst.set_dims_mapping(out_dims_mapping);
// step2.3: Merge and get Inputs' New Dims Mapping. // step2.3: Merge and get Inputs' New Dims Mapping.
TensorDistAttr x_dist_attr_dst = GetInferedDistAttr( TensorDistAttr x_dist_attr_dst =
input_specs[0].DistAttr, input_specs[0].shape, x_axes, axis_to_dim_map); GetInferedDistAttr(x_dist_attr_src, x_shape, x_axes, axis_to_dim_map);
TensorDistAttr y_dist_attr_dst = GetInferedDistAttr( TensorDistAttr y_dist_attr_dst =
input_specs[1].DistAttr, input_specs[1].shape, y_axes, axis_to_dim_map); GetInferedDistAttr(y_dist_attr_src, y_shape, y_axes, axis_to_dim_map);
// step2.3: Handle Partial // step2.3: Handle Partial
// Step2.3.1 Output Partial // Step2.3.1 Output Partial
...@@ -149,18 +155,17 @@ std::vector<TensorDistAttr> MatmulSPMDRule::InferForward( ...@@ -149,18 +155,17 @@ std::vector<TensorDistAttr> MatmulSPMDRule::InferForward(
ResoluteOutputPartialDimension(axis_to_dim_map, out_axes); ResoluteOutputPartialDimension(axis_to_dim_map, out_axes);
// Step2.3.2 handle input tensor partial (TODO) // Step2.3.2 handle input tensor partial (TODO)
VLOG(4) << "MatmulSPMDRule InferForward: " VLOG(4) << "MatmulSPMDRule InferForward: "
<< "X shape: " << input_specs[0].shape << "X shape: [" << str_join(x_shape) << "], src_dims_mapping: ["
<< ", src_dims_mapping: " << x_dims_mapping << str_join(x_dims_mapping) << "], dst_dims_mapping: ["
<< ", dst_dims_mapping: " << x_dist_attr_dst.dims_mapping << str_join(x_dist_attr_dst.dims_mapping()) << "]; Y shape: ["
<< "; Y shape: " << input_specs[1].shape << str_join(y_shape) << "], src_dims_mapping: ["
<< ", src_dims_mapping: " << x_dims_mapping << str_join(x_dims_mapping) << "], dst_dims_mapping: ["
<< ", dst_dims_mapping: " << y_dist_attr_dst.dims_mapping << str_join(y_dist_attr_dst.dims_mapping())
<< "; Output dims_mapping: " << out_dims_mapping << "]; Output dims_mapping: [" << str_join(out_dims_mapping)
<< ", partial_on_dims: " << partial_on_dims; << "], partial_on_dims: [" << str_join(partial_on_dims) << "]";
return { x_dist_attr_dst, y_dist_attr_dst, output_dist_attr_dst } return {x_dist_attr_dst, y_dist_attr_dst, output_dist_attr_dst};
} }
TensorDistAttr GetInferedDistAttr( TensorDistAttr GetInferedDistAttr(
...@@ -170,13 +175,19 @@ TensorDistAttr GetInferedDistAttr( ...@@ -170,13 +175,19 @@ TensorDistAttr GetInferedDistAttr(
const std::unordered_map<std::string, int64_t>& axis_to_dim_map) { const std::unordered_map<std::string, int64_t>& axis_to_dim_map) {
TensorDistAttr dist_attr_ = CopyTensorDistAttrForOutput(origin_dist_attr); TensorDistAttr dist_attr_ = CopyTensorDistAttrForOutput(origin_dist_attr);
std::vector<int64_t> infered_dims_mapping; std::vector<int64_t> infered_dims_mapping;
infered_dims_mapping.reserve(tensor_string.size()); infered_dims_mapping.reserve(tensor_axis.size());
for (int i = 0; i < tensor_axis.size(); ++i) { for (size_t i = 0; i < tensor_axis.size(); ++i) {
if (shape.size() > i && shape[i] == 1) { if (shape.size() > i && shape[i] == 1) {
infered_dims_mapping.push_back(-1); infered_dims_mapping.push_back(-1);
} else { } else {
infered_dims_mapping.push_back(axis_to_dim_map[tensor_axis.substr(i, 1)]); auto itr = axis_to_dim_map.find(tensor_axis.substr(i, 1));
if (itr == axis_to_dim_map.end()) {
phi::errors::InvalidArgument(
"Tensor axis [%s] of not in axis_to_dim_map.",
tensor_axis.substr(i, 1));
}
infered_dims_mapping.push_back(itr->second);
} }
} }
...@@ -186,7 +197,12 @@ TensorDistAttr GetInferedDistAttr( ...@@ -186,7 +197,12 @@ TensorDistAttr GetInferedDistAttr(
std::vector<TensorDistAttr> MatmulSPMDRule::InferBackward( std::vector<TensorDistAttr> MatmulSPMDRule::InferBackward(
const std::vector<DistTensorSpec>& output_specs, const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {} const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of MatmulSPMDRule is NOT implemented yet."));
return {};
}
} // namespace auto_parallel } // namespace auto_parallel
} // namespace distributed } // namespace distributed
......
...@@ -33,14 +33,14 @@ TensorDistAttr GetInferedDistAttr( ...@@ -33,14 +33,14 @@ TensorDistAttr GetInferedDistAttr(
class MatmulSPMDRule : public SPMDRuleBase { class MatmulSPMDRule : public SPMDRuleBase {
public: public:
std::vector<DistTensorSpec> InferForward( std::vector<TensorDistAttr> InferForward(
const std::vector<DistTensorSpec>& input_specs, const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) override; const paddle::framework::AttributeMap& attrs) override;
std::vector<DistTensorSpec> InferBackward( std::vector<TensorDistAttr> InferBackward(
const std::vector<DistTensorSpec>& output_specs, const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override; const paddle::framework::AttributeMap& attrs) override;
} };
} // namespace auto_parallel } // namespace auto_parallel
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册