[RFC] computeKnownBits recursion depth

computeKnownBits calls are used widely throughout optimization passes, as they enable bit-aware optimizations and code canonicalization.

As an example, InstCombine will use computeKnownBits to canonicalize xor A, B to or disjoint A, B if the known bits of A and B are disjoint – that is, there is no bit position where A and B can both be 1.
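
For reference, the disjointness check can be expressed roughly as follows; this is only a sketch against the ValueTracking API (not InstCombine's actual code), and the exact API details vary between LLVM versions:

	#include "llvm/Analysis/ValueTracking.h"
	#include "llvm/IR/DataLayout.h"
	#include "llvm/Support/KnownBits.h"
	using namespace llvm;

	// Sketch: an `xor A, B` can be rewritten as `or disjoint A, B` when every
	// bit position is known zero in at least one of the two operands.
	static bool xorOperandsAreDisjoint(const Value *A, const Value *B,
	                                   const DataLayout &DL) {
	  KnownBits KnownA = computeKnownBits(A, DL); // recursive walk over A's operands
	  KnownBits KnownB = computeKnownBits(B, DL); // recursive walk over B's operands
	  return (KnownA.Zero | KnownB.Zero).isAllOnes();
	}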

These computeKnownBits calls recursively analyze the instruction's operands to gather as much information as possible about the KnownBits of the instruction in question. For long instruction sequences, computing the KnownBits can therefore involve an expensive recursive analysis, and these calls are a common source of increased compile time (see the "Make LLVM fast again" post).

To bound this cost, the recursion uses a hardcoded depth limit of 6 (MaxAnalysisRecursionDepth). However, in certain cases this limit leads to significant degradations in the quality of the generated code.

A notable example is Triton's Linear Layout (triton/include/triton/Tools/LinearLayout.h at 981e987eed9053b952f81153bc0779c99d8c642e in triton-lang/triton on GitHub). This feature provides a simplified way of mapping hardware data locations in GPUs to tensor indexes, so less computation is needed to calculate memory addresses. However, by design of this feature, the memory address calculations now involve xors.

In a specific function using Linear Layout, we end up with the following IR:

	%130 = or disjoint i32 %128, %129, !dbg !34
	%131 = getelementptr inbounds nuw half, ptr addrspace(3) @global_smem, i32 %130, !dbg !34
	store <4 x i32> %71, ptr addrspace(3) %131, align 16, !dbg !34
	%132 = xor i32 %130, 4096, !dbg !34
	%133 = getelementptr inbounds nuw half, ptr addrspace(3) @global_smem, i32 %132, !dbg !34
	store <4 x i32> %74, ptr addrspace(3) %133, align 16, !dbg !34
	%134 = xor i32 %130, 8192, !dbg !34
	%135 = getelementptr inbounds nuw half, ptr addrspace(3) @global_smem, i32 %134, !dbg !34
	store <4 x i32> %77, ptr addrspace(3) %135, align 16, !dbg !34
	%136 = xor i32 %130, 12288, !dbg !34
	%137 = getelementptr inbounds nuw half, ptr addrspace(3) @global_smem, i32 %136, !dbg !34
	store <4 x i32> %80, ptr addrspace(3) %137, align 16, !dbg !34

We see that the final instruction involved in the base address calculation is an xor. These xors have disjoint operands, but they are not canonicalized to or disjoint due to the recursion depth limit.

For a specific xor (%132), the entirety of the address calculation is:

	%28 = tail call i32 @llvm.amdgcn.workitem.id.x(), !dbg !26
	%29 = and i32 %28, 8, !dbg !26
	%.not = icmp eq i32 %29, 0, !dbg !26
	%30 = and i32 %28, 16, !dbg !26
	%31 = icmp eq i32 %30, 0, !dbg !26
	%32 = and i32 %28, 32, !dbg !26
	%33 = icmp eq i32 %32, 0, !dbg !26
	%34 = and i32 %28, 256, !dbg !26
	%53 = shl i32 %28, 3, !dbg !29
	%54 = and i32 %53, 56, !dbg !29 
	%121 = select i1 %.not, i32 0, i32 72
	%122 = select i1 %31, i32 0, i32 144
	%123 = or disjoint i32 %121, %122
	%124 = select i1 %33, i32 0, i32 288
	%125 = or disjoint i32 %123, %124
	%126 = xor i32 %125, %54
	%127 = and i32 %53, 1536
	%128 = or disjoint i32 %127, %126
	%129 = shl nuw nsw i32 %34, 3
	%130 = or disjoint i32 %128, %129, !dbg !34
	%132 = xor i32 %130, 4096, !dbg !34

For the %132 instruction, the longest recursion path before reaching a leaf node is:

%130 → %128 → %126 → %125 → %123 → %121 → %.not → %29 → %28

Thus, the recursion exceeds the depth limit and we end up with xors as the final instructions in the address calculations. SeparateConstOffsetFromGEP does not currently analyze xor operands, so it does not attempt to fold the constants into the GEPs. For these particular GEPs, the users (stores) immediately follow them in the same block, so CodeGenPrepare address sinking does not apply. However, we also have loads in other blocks whose GEPs are not sunk, since our target (AMDGPU) does not find a matching AddrMode. As a result, we end up with many distinct base addresses for these instead of a few bases with constant offsets.

For this particular case, if we use a new computeKnownBits API which overrides the depth limit, the result is 18 fewer registers used for the load addresses. Importantly, this also eliminates spilling and doubles the FLOPS.
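
To be concrete about what "overrides the depth limit" means here: the API is not upstream, just a variant along these lines (hypothetical name and signature):

	// Hypothetical variant: identical to computeKnownBits, except the caller
	// supplies the maximum recursion depth instead of the fixed limit of 6.
	KnownBits computeKnownBitsWithDepthLimit(const Value *V, const DataLayout &DL,
	                                         unsigned MaxDepth);

	// e.g. a client that is willing to pay for a deeper walk:
	//   KnownBits Known = computeKnownBitsWithDepthLimit(Xor, DL, /*MaxDepth=*/16);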

The conclusion is that the recursion depth limit is significantly hurting runtime performance. We are exploring solutions which clean up the instruction sequences involved in calculating the addresses, but this type of approach is not very robust and may result in missed optimizations. The more stable solution is to make the recursion depth limit more flexible. It seems like some computeKnownBits based optimizations (e.g. the xor → or disjoint canonicalization that SeparateConstOffsetFromGEP benefits from) can provide a stronger performance uplift than others, so it makes sense to me that some clients use a more relaxed depth limit than others. However, I'm curious whether the community agrees and what the consensus is for issues of this type.

It would be great if we could have a KnownBitsAnalysis that cached the results of these queries, instead of recomputing them for every subexpression on every call to computeKnownBits. That could remove the need for any kind of recursion limit.

IIUC this is hard to implement because we would need a reliable way to invalidate parts of the cache whenever the IR is modified.
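
For what it's worth, the new-pass-manager shape of such an analysis is easy enough to sketch; the hard part, invalidation, is exactly what's missing here (hypothetical code):

	#include "llvm/ADT/DenseMap.h"
	#include "llvm/IR/Function.h"
	#include "llvm/IR/PassManager.h"
	#include "llvm/Support/KnownBits.h"
	using namespace llvm;

	// Sketch of a caching KnownBits analysis: results are keyed by Value and
	// reused across queries within a function. Keeping the cache correct as
	// passes mutate the IR is the unsolved part.
	class KnownBitsAnalysis : public AnalysisInfoMixin<KnownBitsAnalysis> {
	  friend AnalysisInfoMixin<KnownBitsAnalysis>;
	  static AnalysisKey Key;

	public:
	  struct Result {
	    DenseMap<const Value *, KnownBits> Cache; // lazily filled per query
	  };
	  Result run(Function &F, FunctionAnalysisManager &AM) { return Result(); }
	};
	AnalysisKey KnownBitsAnalysis::Key;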

I am not very familiar with the computeKnownBits infrastructure.
Here are some of the things I can think of for making this recursion depth dynamic:

  1. Based on function complexity, like the number of basic blocks/instructions, presence of loops, or CFG structure.
  2. If we can track how much progress we make at each sweep (how many new bits become known), we can reason about whether to increase the recursion count or not.

If we know why the recursion depth is set to 6 (is it because of loops, or phis feeding phis?), we could also make the recursion depth dynamic by special-casing these common patterns.
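
As a strawman for (1), the depth budget could be derived from cheap function-level counts; the thresholds in this sketch are made up:

	#include "llvm/IR/Function.h"
	using namespace llvm;

	// Hypothetical heuristic: allow a deeper computeKnownBits walk in small
	// functions, and fall back to the current conservative limit in big ones.
	static unsigned pickKnownBitsDepthLimit(const Function &F) {
	  unsigned NumInsts = 0;
	  for (const BasicBlock &BB : F)
	    NumInsts += BB.size();
	  if (NumInsts < 1000)
	    return 16;
	  if (NumInsts < 10000)
	    return 8;
	  return 6; // existing hardcoded limit
	}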

Thanks for the comment!

Yeah, I was thinking of something like this as well – for the invalidation problem we could build the ValueTracking machinery into an analysis pass and use the normal invalidation mechanisms. But then, most passes that use the ValueTracking analysis would likely invalidate it, which could actually end up making things more expensive. Not to mention the overhead of using a pass versus the current approach of static utility functions.

I think the issue is that we are looking top-down at a problem that is defined bottom-up. For example, if we integrated the KnownBits into the instructions themselves, we could compute the KnownBits from the operands during construction. With this approach, we may be able to get perfect info in O(n) or close to it (n being the number of operands). I wonder if we should work towards this direction in the long term, or whether there are fundamental issues with this approach that I'm not thinking of.
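
To make the bottom-up idea concrete, here's a rough sketch (hypothetical, and it entirely ignores in-place mutation) where each new instruction's entry costs only a transfer function over its operands' already-recorded bits:

	#include "llvm/ADT/DenseMap.h"
	#include "llvm/IR/Instructions.h"
	#include "llvm/Support/KnownBits.h"
	using namespace llvm;

	// Hypothetical side table filled bottom-up at construction time. Operands
	// are created before their users, so their entries already exist (constants
	// and arguments would need to be seeded separately).
	static DenseMap<const Value *, KnownBits> BitsByValue;

	static KnownBits transferKnownBits(unsigned Opcode, const KnownBits &L,
	                                   const KnownBits &R) {
	  KnownBits Out(L.getBitWidth());
	  switch (Opcode) {
	  case Instruction::And:
	    Out.One = L.One & R.One;
	    Out.Zero = L.Zero | R.Zero;
	    break;
	  case Instruction::Xor:
	    Out.One = (L.One & R.Zero) | (L.Zero & R.One);
	    Out.Zero = (L.Zero & R.Zero) | (L.One & R.One);
	    break;
	  default:
	    break; // everything else stays unknown in this sketch
	  }
	  return Out;
	}

	// O(#operands) work per newly created instruction.
	static void recordOnCreate(const BinaryOperator *BO) {
	  BitsByValue[BO] = transferKnownBits(BO->getOpcode(),
	                                      BitsByValue[BO->getOperand(0)],
	                                      BitsByValue[BO->getOperand(1)]);
	}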

Thanks for the comment! Interesting thought to make the depth limit dynamic – I hadn't really considered this. My gut feeling is that this would create more things to maintain and measure, though we would probably get closer to the optimal tradeoff of compile time vs. performance.

The basic problem is that the query needs to be cheap, because we do it all over the place. Computing known bits for the whole function isn’t actually that expensive (Attributor and SCCP do similar computations), but you’d need to cache it. And keeping any sort of cache up-to-date as code is transformed is hard and expensive: it’s hard to prove a given change doesn’t invalidate the whole cache, unless you restrict what changes are actually allowed.

Playing with the recursion depth might help a little for specific cases, but you’ll always hit a cutoff at some point, so I’m not sure that’s really the best approach.


My suspicion is that just sticking the KnownBits on instructions will have invalidation issues, especially after in-place mutations (which are decently common).

For instance, suppose I set nneg on a zext. That will require the zext itself to update its known bits and will need to trigger every transitive user of that zext to re-evaluate theirs (and you'd better hope you didn't miss a call site).
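
Concretely, every such refinement turns into a transitive walk over users, roughly like the (hypothetical) bookkeeping below:

	#include "llvm/ADT/SmallPtrSet.h"
	#include "llvm/ADT/SmallVector.h"
	#include "llvm/IR/Instruction.h"
	using namespace llvm;

	// Hypothetical: after refining `Changed` (e.g. by setting nneg), every
	// transitive user has to have its cached KnownBits recomputed.
	static void propagateRefinement(Instruction *Changed) {
	  SmallVector<Instruction *, 16> Worklist = {Changed};
	  SmallPtrSet<Instruction *, 16> Visited;
	  while (!Worklist.empty()) {
	    Instruction *I = Worklist.pop_back_val();
	    if (!Visited.insert(I).second)
	      continue;
	    // recomputeCachedKnownBits(I);  // hypothetical cache refresh
	    for (User *U : I->users())
	      if (auto *UI = dyn_cast<Instruction>(U))
	        Worklist.push_back(UI);
	  }
	}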

Cache invalidation on changing IR is a problem in general but should be manageable for computeKnownBits. Case analysis goes:

  • Previously unknown bit becomes computable and we don't notice: safe.
  • Previously known bit is no longer computable and we don't clear it: weird, but presumably safe.
  • Previously known bit is now believed to have the opposite value: miscompilation.

That suggests we could have a cache / bits on the value / side table / whatever containing the current best guess at KnownBits, and accept that it’s not necessarily at fixpoint. When we want to do the call, try the cached value. If it’s not enough info for the call site, either move on or try to compute some more, maybe to the current recursion depth, adding any new information discovered to the cache as we go.
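
In code, that scheme could be as simple as the following (hypothetical) sketch:

	#include "llvm/ADT/DenseMap.h"
	#include "llvm/Analysis/ValueTracking.h"
	#include "llvm/IR/DataLayout.h"
	#include "llvm/Support/KnownBits.h"
	using namespace llvm;

	// Hypothetical best-effort cache: hand back whatever we currently believe,
	// and only spend more effort (bounded by the usual depth limit) when the
	// cached entry doesn't answer the query; remember anything new we learn.
	struct BestEffortKnownBitsCache {
	  DenseMap<const Value *, KnownBits> Cache;

	  KnownBits get(const Value *V, const DataLayout &DL) {
	    auto It = Cache.find(V);
	    if (It != Cache.end() && !It->second.isUnknown())
	      return It->second;
	    KnownBits Fresh = computeKnownBits(V, DL);
	    Cache[V] = Fresh;
	    return Fresh;
	  }
	};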

I wouldn’t say that’s necessarily a brilliant thing for compiler devs trying to work out why an optimisation missed but I think it is legitimately a case where only sporadically fixing up the cache information could be safe.

Not safe. This could actually be case three (i.e. a change in the known bits), with the implementation merely unable to prove the new value. Reporting the old value is still unsound.

The invalidation problem is hard, but it is further complicated by context sensitivity (Q.CtxI): we would need to drop it altogether if we're thinking about a KnownBitsAnalysis. The regression from dropping it might be relatively small, though.

I think invalidating the cache and re-computing will be too expensive, and it might not be possible to craft a partial invalidation without knowing exactly what the IR changes in question are. Is it then possible to incrementally update the KnownBits every time there's an IR change in the caller? We could have a KnownBitsTree, similar to DominatorTree but with KnownBits at each node, and wrap the IR-mutation APIs in a KnownBitsUpdater. So what I'm thinking of is something like:

	// Sketch only: KnownBitsTree and NoWrapFlags are placeholder types.
	class KnownBitsUpdater {
	  KnownBitsTree<Value *> KnownTree; // per-Value KnownBits, DominatorTree-like
	  // ...

	public:
	  // Seed the tree with KnownBits for every value in F.
	  void compute(Function &F);

	  // IR mutations go through the updater so the tree stays in sync.
	  void replaceAllUsesWith(Value *V, Value *New) {
	    V->replaceAllUsesWith(New);
	    // compute KnownBits of New, and propagate it in the tree
	  }

	  void setNoWrapFlags(Value *V, NoWrapFlags NW);
	  // ...
	};