20_simt_canonical
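This is a classic example of running GEMM on CUDA cores (SIMT, no Tensor Cores). The key thing to understand is how the warp-level GEMM is configured.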
Policy
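using Policy = cutlass::gemm::warp::MmaSimtPolicy<
cutlass::MatrixShape<4, 8>,               // WarpShape: the 32 lanes form 4 rows x 8 columns
cutlass::layout::RowMajorInterleaved<2>,  // LaneLayout: maps a lane ID to a (row, column) cell
cutlass::gemm::GemmShape<4, 4, 1>         // LaneMmaShape: per-thread M x N x K unit
>;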
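Policy describes how the threads of a warp are laid out. Here the 32 lanes form a grid of 4 rows by 8 columns. Which cell of that 2D grid a given lane ID lands in is decided by the lane layout, here RowMajorInterleaved<2>. The following code, placed inside the kernel, prints the mapping: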
// inside the kernel, with lane_id = threadIdx.x % 32
auto lane_layout = Policy::get_lane_layout();
cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id);
printf("threadIdx.x==%d, ##[%d, %d]\n", int(threadIdx.x), lane_offset.row(), lane_offset.column());
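To see the mapping without launching a kernel, here is a host-side sketch that re-implements the lane-to-coordinate arithmetic. It assumes that CUTLASS's RowMajorInterleaved<2>::inverse() splits an offset into (row pair, column, row within pair) and that the packed lane layout uses a stride of columns x interleave = 16; the names lane_to_coord and print_lane_grid are hypothetical, not CUTLASS functions.

```cpp
#include <cstdio>

// Hypothetical re-implementation (not the CUTLASS source) of
// RowMajorInterleaved<2>::inverse() over a packed 4 x 8 lane grid,
// assuming the packed stride is columns * interleave = 16.
struct LaneCoord { int row; int column; };

constexpr int kWarpRows = 4;     // MatrixShape<4, 8>
constexpr int kWarpColumns = 8;
constexpr int kInterleave = 2;   // RowMajorInterleaved<2>

constexpr LaneCoord lane_to_coord(int lane_id) {
  int stride = kWarpColumns * kInterleave;  // 16 lanes per pair of rows
  int row_major = lane_id / stride;         // which pair of rows
  int residual = lane_id % stride;
  int column = residual / kInterleave;      // two consecutive lanes share a column
  int row_minor = residual % kInterleave;   // ...and sit in adjacent rows
  return LaneCoord{row_major * kInterleave + row_minor, column};
}

// Prints the lane ID that lands in each cell of the 4 x 8 grid.
inline void print_lane_grid() {
  int grid[kWarpRows][kWarpColumns] = {};
  for (int lane = 0; lane < 32; ++lane) {
    LaneCoord c = lane_to_coord(lane);
    grid[c.row][c.column] = lane;
  }
  for (int r = 0; r < kWarpRows; ++r) {
    for (int c = 0; c < kWarpColumns; ++c) std::printf("%3d", grid[r][c]);
    std::printf("\n");
  }
}
```

Calling print_lane_grid() from a main gives row 0 = lanes 0, 2, 4, ..., 14, row 1 = lanes 1, 3, ..., 15, and rows 2 and 3 repeat the same pattern for lanes 16 to 31. If the assumption about the packed stride holds, this should match what the device-side printf above reports.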
The last parameter, GemmShape<4, 4, 1>, is the lane shape (LaneMmaShape): each thread's basic unit of work is a 4x4 block of the output, with k = 1.
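Combining the 4 x 8 lane grid with the 4 x 4 lane shape gives the area of C that one warp covers per lane-MMA step, and from a warp tile size one can derive how much each thread owns. The sketch below does only that arithmetic; the 32 x 64 warp tile is a hypothetical value (the example's actual warp shape is not shown in this note), and it assumes the warp tile is divided evenly by the footprint.

```cpp
// Shape arithmetic implied by the Policy. The 32 x 64 warp tile used in the
// checks is hypothetical; only the 4 x 8 grid and 4 x 4 lane shape come from
// the Policy above.
struct Shape2D { int rows; int columns; };

constexpr Shape2D kWarpGrid{4, 8};   // MatrixShape<4, 8>
constexpr Shape2D kLaneMma{4, 4};    // LaneMmaShape M x N (k = 1)

// Output area covered by one lane-MMA step across all 32 lanes.
constexpr Shape2D warp_footprint() {
  return Shape2D{kWarpGrid.rows * kLaneMma.rows,
                 kWarpGrid.columns * kLaneMma.columns};
}

// Accumulator block owned by each thread for a warp tile of m x n.
constexpr Shape2D per_thread_tile(int m, int n) {
  return Shape2D{m / kWarpGrid.rows, n / kWarpGrid.columns};
}

// How many 4 x 4 lane blocks each thread handles along M and N.
constexpr Shape2D lane_mma_repeats(int m, int n) {
  return Shape2D{m / warp_footprint().rows, n / warp_footprint().columns};
}
```

With these numbers one step covers 16 x 32 of C, and a 32 x 64 warp tile would give each thread an 8 x 8 accumulator block made of 2 x 2 lane blocks.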
warp iterator
Start with the A operand.
typename MmaWarp::IteratorA iter_A(ref_A, {Shape::kM, Shape::kK}, lane_id);
The IteratorA type is defined in mma_simt.h:
/// Iterates over the A operand in memory
using IteratorA = MmaSimtTileIterator<
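MatrixShape<Shape::kM, Policy::LaneMmaShape::kK>, // size of the tile covered per step
Operand::kA, // kA / kB / kC / kD mainly exist to select a partial specialization
ElementA,
LayoutA,
Policy, // the Policy defined above
PartitionsK,
Shape::kK // K extent of the whole warp-level GEMM (bound to PartitionGroupSize)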
>;
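The first template argument says that each step of IteratorA covers a Shape::kM x LaneMmaShape::kK slice of A, which here is a single column of A (one k index). A small sketch of that arithmetic, using a hypothetical warp tile of 32 x 64 x 8 (M x N x K) since the example's real Shape is not shown here:

```cpp
// Hypothetical warp tile of 32 x 64 x 8 (M x N x K); only LaneMmaShape::kK = 1
// comes from the Policy above.
constexpr int kWarpM = 32;
constexpr int kWarpK = 8;
constexpr int kLaneK = 1;   // Policy::LaneMmaShape::kK

// Tile IteratorA covers per step: MatrixShape<Shape::kM, LaneMmaShape::kK>.
constexpr int kIterTileRows = kWarpM;     // all 32 rows of the warp's A tile
constexpr int kIterTileColumns = kLaneK;  // one column of A, i.e. one k index

// Steps needed to walk the warp tile's full K extent.
constexpr int kKSteps = kWarpK / kLaneK;
```

Under these assumptions the iterator would be advanced 8 times to cover K, loading one 32 x 1 slice of A per step.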
Next, the definition of the iterator itself. We can go straight to the partial specialization in mma_simt_tile_iterator.h:
/// Specialization for A operands of row-major layouts
///
/// Concept: MutableRandomAccessContiguousTileIteratorConcept
///
template <
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Data type of A elements
typename Element_,
/// Shape of the warp in units of thread (concept: MmaSimtPolicy)
typename Policy_,
/// Number of partitions along K dimension - used in sliced-K
int PartitionsK,
/// Group Size along kPartition - used in sliced-K
int PartitionGroupSize
>
class MmaSimtTileIterator<Shape_, Operand::kA, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize> { /* ... */ };
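It is the combination of Operand::kA and the layout type that selects this specialization. The self-contained sketch below shows the same tag-dispatch technique with hypothetical stand-ins: Operand, RowMajor, ColumnMajor and TileIterator are defined locally here and are not CUTLASS types.

```cpp
// Hypothetical stand-ins illustrating partial specialization on an
// operand tag plus a layout type; not the CUTLASS definitions.
enum class Operand { kA, kB, kC, kD };
struct RowMajor {};
struct ColumnMajor {};

// Primary template is only declared, so an unsupported
// (operand, layout) combination fails to compile.
template <typename Shape, Operand kOperand, typename Element, typename Layout>
class TileIterator;

// Chosen for A stored row-major.
template <typename Shape, typename Element>
class TileIterator<Shape, Operand::kA, Element, RowMajor> {
 public:
  static constexpr const char *kName = "A, row-major";
};

// Chosen for A stored column-major.
template <typename Shape, typename Element>
class TileIterator<Shape, Operand::kA, Element, ColumnMajor> {
 public:
  static constexpr const char *kName = "A, column-major";
};

// Chosen for B stored row-major.
template <typename Shape, typename Element>
class TileIterator<Shape, Operand::kB, Element, RowMajor> {
 public:
  static constexpr const char *kName = "B, row-major";
};
```

The design choice is that each (operand, layout) pair gets its own class body, so the offset and load logic can differ per case without runtime branching.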
The constructor used in this example is:
MmaSimtTileIterator(
TensorRef ref,
TensorCoord extent,
int lane_id
) : extent_(extent), divisible_ (false) {
// compute offset based on thread ID and lane layout
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
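MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
MatrixCoord(Policy::LaneMmaShape::kM, 0); // element-wise product: (lane_row * kM, 0)
origin_ = lane_offset;                    // remember this lane's starting coordinate
ref.add_coord_offset(lane_offset);        // move the base pointer to that coordinate
ref_.reset(ref.data(), ref.stride(0));    // keep the shifted pointer and the leading dimension
#if defined(__CUDA_ARCH__)
// debug trace; threadIdx is only available in device code
if (threadIdx.x == 0) printf("@##########################@\n");
#endif
}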
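The key line is the element-wise product inverse(lane_id) * MatrixCoord(kM, 0): it keeps only the lane's row in the 4 x 8 grid, scales it by LaneMmaShape::kM = 4, and zeroes the k coordinate. The host-side sketch below reproduces that arithmetic under the same RowMajorInterleaved<2> assumption as before; a_lane_offset, a_pointer_offset and the leading dimension lda are hypothetical names, not CUTLASS API.

```cpp
// Hypothetical host-side sketch of the constructor's offset arithmetic,
// assuming the RowMajorInterleaved<2> mapping over a 4 x 8 lane grid and
// LaneMmaShape::kM = 4. Not the CUTLASS source.
struct Coord { int row; int column; };

constexpr int kWarpColumns = 8;
constexpr int kInterleave = 2;
constexpr int kLaneM = 4;   // Policy::LaneMmaShape::kM

// Lane ID -> (row, column) in the 4 x 8 lane grid.
constexpr Coord lane_to_coord(int lane_id) {
  int stride = kWarpColumns * kInterleave;
  int residual = lane_id % stride;
  return Coord{(lane_id / stride) * kInterleave + residual % kInterleave,
               residual / kInterleave};
}

// Mirrors lane_layout.inverse(lane_id) * MatrixCoord(kM, 0):
// element-wise, so only the lane row survives, scaled by kM.
constexpr Coord a_lane_offset(int lane_id) {
  Coord c = lane_to_coord(lane_id);
  return Coord{c.row * kLaneM, c.column * 0};
}

// Element offset applied to a row-major A pointer with leading dimension lda
// (row * lda + column, which is what add_coord_offset amounts to for RowMajor).
constexpr long a_pointer_offset(int lane_id, long lda) {
  Coord o = a_lane_offset(lane_id);
  return o.row * lda + o.column;
}
```

Under these assumptions every lane starts at k = 0, and the 8 lanes in the same grid row start at the same row of A (rows 0, 4, 8 or 12), so each A block is read by a whole row of lanes.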