I tried to order my thoughts around this and came up with the following description of what I would like to do:
Operations
These are descriptions of what the operations should compute, not how they should be implemented (see the next section for that).
Reduction
Outputs a scalar containing the sum / product / minimum / maximum / bitwise AND / bitwise OR / bitwise XOR of all elements in the input array.
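To pin down the semantics (not the GPU implementation), here is a sequential sketch: each reduction is just a fold over the input with one associative operator.

```python
from functools import reduce
import operator

data = [3, 1, 4, 1, 5]

# Each reduction combines all elements with one associative operator.
total   = reduce(operator.add, data)   # sum
product = reduce(operator.mul, data)   # product
minimum = reduce(min, data)            # minimum
bit_or  = reduce(operator.or_, data)   # bitwise OR
```

Associativity is what makes the parallel tree-shaped evaluation on the GPU equivalent to this left fold.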
Exclusive / Inclusive Scan
For each element, reduce all elements up to that element, excluding or including the element itself, respectively.
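A sequential sketch of both variants, using addition as the reduction operator; the exclusive scan starts from the operator's identity element (0 for addition).

```python
from itertools import accumulate

data = [3, 1, 4, 1, 5]

# Inclusive scan: element i is the reduction of data[0..i].
inclusive = list(accumulate(data))

# Exclusive scan: element i is the reduction of data[0..i-1],
# shifted right with the identity element prepended.
exclusive = [0] + inclusive[:-1]
```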
Aggregation
Partitions (key-value-pair) elements by the provided unsigned integer key. Note: this is not a generic sort, as it is not comparison-based and cannot handle floats or other types of keys.
Selection
An aggregation with a 1-bit key where one partition is discarded. Thus, it selects elements by a boolean condition.
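Sequentially this degenerates to a filter: the predicate is the 1-bit key, and only the "true" partition is kept.

```python
data = [3, 1, 4, 1, 5]

# 1-bit key = the predicate; the "false" partition is discarded.
selected = [x for x in data if x > 2]
```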
Binning
Sorts first using aggregation, then reduces each partition to a bin. Thus the number of distinct keys in the input equals the number of bins in the output (which makes the result a histogram).
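A sequential sketch with sum as the per-partition reduction; with a count instead of a sum this is exactly a histogram.

```python
keys   = [2, 0, 1, 0, 2, 2]
values = [10, 20, 30, 40, 50, 60]
num_bins = 3

bins = [0] * num_bins
for k, v in zip(keys, values):
    bins[k] += v  # per-partition reduction (here: sum)

# Counting instead of summing yields the histogram of the keys.
histogram = [0] * num_bins
for k in keys:
    histogram[k] += 1
```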
Run Length Encoding
Assuming the input is already sorted (e.g. by an aggregation), counts the number of elements (which share the same key) in each partition.
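Sequentially, this is counting the run lengths of equal keys in the sorted input:

```python
from itertools import groupby

sorted_keys = [0, 0, 1, 2, 2, 2]

# (key, run length) for each run of equal keys.
runs = [(k, len(list(g))) for k, g in groupby(sorted_keys)]
```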
FFT
Swaps the time and frequency domain of a complex input to a complex output, encoded as vec2 / 2D floats. The reverse direction can be obtained by additionally multiplying with a 1/N scaling factor.
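As a semantic reference (deliberately the naive O(N²) definition, not the fast algorithm), the forward and reverse transforms differ only in the sign of the exponent and the 1/N scaling factor mentioned above:

```python
import cmath

def dft(x):
    """Direct O(N^2) DFT; defines what the FFT operation must compute."""
    n = len(x)
    return [sum(x[j] * cmath.exp(-2j * cmath.pi * k * j / n) for j in range(n))
            for k in range(n)]

def idft(x):
    # Reverse direction: conjugated exponent plus the 1/N scaling factor.
    n = len(x)
    return [sum(x[j] * cmath.exp(2j * cmath.pi * k * j / n) for j in range(n)) / n
            for k in range(n)]
```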
Convolution
Uses one input array as a sliding window over the other input array and outputs the dot product for each shift offset.
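A direct sequential sketch of the full (linear) convolution; note that convolution, as opposed to correlation, effectively slides the reversed window across the signal.

```python
a = [1, 2, 3]    # signal
b = [1, 0, -1]   # sliding window (kernel)

# Full convolution: one output per shift offset.
n = len(a) + len(b) - 1
conv = [0] * n
for i, x in enumerate(a):
    for j, y in enumerate(b):
        conv[i + j] += x * y
```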
Possible Implementations
- Reduction: Already exists up to workgroup level
- Scan, Aggregation, Selection, Binning, Run Length Encoding: Single-pass Parallel Prefix Scan with Decoupled Look-back
- FFT: Cooley–Tukey algorithm with power-of-two radices.
- Convolution: Perform one FFT for the two inputs each, multiply them element wise and do another (inverse) FFT on the product to get the output.
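A sequential sketch of the last two bullets combined: a radix-2 Cooley–Tukey FFT (power-of-two sizes only), the inverse obtained via conjugation and the 1/N factor, and convolution as element-wise multiplication in the frequency domain. Input sizes, padding, and helper names are my own for illustration.

```python
import cmath

def fft(x):
    """Recursive radix-2 Cooley-Tukey FFT; len(x) must be a power of two."""
    n = len(x)
    if n == 1:
        return list(x)
    even = fft(x[0::2])
    odd = fft(x[1::2])
    out = [0j] * n
    for k in range(n // 2):
        t = cmath.exp(-2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + t
        out[k + n // 2] = even[k] - t
    return out

def ifft(x):
    # Inverse via the forward FFT: conjugate, transform, conjugate, scale by 1/N.
    n = len(x)
    return [v.conjugate() / n for v in fft([v.conjugate() for v in x])]

def fft_convolve(a, b):
    # Zero-pad both inputs to a power of two >= the full convolution length.
    m = len(a) + len(b) - 1
    n = 1
    while n < m:
        n *= 2
    fa = fft(list(a) + [0] * (n - len(a)))
    fb = fft(list(b) + [0] * (n - len(b)))
    # Element-wise product in the frequency domain, then inverse FFT.
    out = ifft([x * y for x, y in zip(fa, fb)])
    return [v.real for v in out[:m]]
```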
Hierarchy
Many of these operations can be built by composition from the lower levels upwards.
It might make sense to expose these lower level versions as well.
- Low level: Subgroup / warp / wavefront / SIMD-vector
- Mid level: Workgroup / thread block
- High level: Kernel dispatch / grid
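As a CPU-side sketch of this composition (the chunk size standing in for the workgroup size): each "workgroup" reduces its chunk to a partial result, and a second pass at the dispatch level reduces the partials.

```python
data = list(range(1, 101))
workgroup_size = 16

# Mid level: each "workgroup" reduces its chunk to one partial result.
partials = [sum(data[i:i + workgroup_size])
            for i in range(0, len(data), workgroup_size)]

# High level: a second pass reduces the partial results.
total = sum(partials)
```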
Higher Dimensions
Scan, FFT and convolution can be done in arbitrary dimensions like this:
- Run 1D once for each row
- Transpose
- Run 1D once for each column
- Transpose
- etc …
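For the 2D case with an additive scan, the steps above are a sketch of building a summed-area table: scan the rows, transpose so the columns become rows, scan again, and transpose back.

```python
from itertools import accumulate

def scan_rows(m):
    # Run the 1D inclusive scan once for each row.
    return [list(accumulate(row)) for row in m]

def transpose(m):
    return [list(col) for col in zip(*m)]

m = [[1, 2],
     [3, 4]]

# 2D inclusive scan: rows, transpose, rows (formerly columns), transpose.
sat = transpose(scan_rows(transpose(scan_rows(m))))
```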
Plan
I would start out by implementing reduction for the other two levels, then do the scan-related operations, and finally the FFT. I guess some would be a set of op-specific passes inside the GPU dialect and others would need lowering to spv, nvvm, etc.