21#include "llvm/IR/GlobalVariable.h"
22#include "llvm/IR/LLVMContext.h"
23#include "llvm/IR/Metadata.h"
24#include "llvm/IR/Module.h"
25#include "llvm/IR/Value.h"
26#include "llvm/Support/Alignment.h"
28#include "llvm/Support/FormatVariadic.h"
31using namespace CodeGen;
// Records a validator version on module M by appending a two-operand metadata
// node (i32 Major, i32 Minor) to the module's "dx.valver" named metadata.
//
// ValVersionStr is the textual version that is handed to Version.tryParse().
// NOTE(review): the declarations of `Version` and `Major`, and the body of the
// rejection branch below, sit on listing lines that are not part of this
// excerpt - confirm against the full file before relying on this summary.
37void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
// The branch is taken when tryParse() yields a truthy result, when the parsed
// version carries a build or subminor component, or when it has no minor
// component.
41 if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
42 Version.getSubminor() || !Version.getMinor()) {
// getMinor() is dereferenced here; the condition above already handled the
// case where it is absent.
47 uint64_t Minor = *Version.getMinor();
49 auto &Ctx = M.getContext();
50 IRBuilder<> B(M.getContext());
// Both components are emitted as 32-bit integer constants.
51 MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
52 ConstantAsMetadata::get(B.getInt32(Minor))});
53 StringRef DXILValKey =
"dx.valver";
// Fetch (or create) the named metadata and append the version node to it.
54 auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
55 DXILValMD->addOperand(Val);
// Adds the module flag "dx.disable_optimizations" with value 1 to M, using the
// Override module-flag behavior.
57void addDisableOptimizations(llvm::Module &M) {
58 StringRef Key =
"dx.disable_optimizations";
59 M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);
85 std::vector<llvm::Type *> EltTys;
87 GlobalVariable *GV =
Const.first;
88 Const.second = EltTys.size();
89 llvm::Type *Ty = GV->getValueType();
90 EltTys.emplace_back(Ty);
92 Buf.
LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);
97 GlobalVariable *CBGV =
new GlobalVariable(
99 GlobalValue::LinkageTypes::ExternalLinkage,
nullptr,
100 llvm::formatv(
"{0}{1}", Buf.
Name, Buf.
IsCBuffer ?
".cb." :
".tb."),
101 GlobalValue::NotThreadLocal);
103 IRBuilder<> B(CBGV->getContext());
104 Value *ZeroIdx = B.getInt32(0);
106 for (
auto &[GV, Offset] : Buf.
Constants) {
108 B.CreateGEP(Buf.
LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});
110 assert(Buf.
LayoutStruct->getElementType(Offset) == GV->getValueType() &&
111 "constant type mismatch");
114 GV->replaceAllUsesWith(GEP);
116 GV->removeDeadConstantUsers();
117 GV->eraseFromParent();
131 llvm_unreachable(
"Generic handling of HLSL types is not supported.");
134llvm::Triple::ArchType CGHLSLRuntime::getArch() {
// Appends one (global, offset) entry for declaration D to CB.Constants.
//
// When the debug-info condition holds (only its last operand,
// LimitedDebugInfo, is visible in this excerpt), the global is also reported
// through DI->EmitGlobalVariable().
// NOTE(review): the lines that obtain `GV`, `DI` and `Offset` are missing from
// this excerpt - confirm their origin against the full file.
138void CGHLSLRuntime::addConstant(
VarDecl *
D, Buffer &CB) {
150 codegenoptions::DebugInfoKind::LimitedDebugInfo)
151 DI->EmitGlobalVariable(cast<GlobalVariable>(GV),
D);
// HasUserOffset is initialised to false; unless an elided line changes it,
// LowerBound evaluates to UINT_MAX and `Offset` is not used.
156 bool HasUserOffset =
false;
158 unsigned LowerBound = HasUserOffset ? Offset :
UINT_MAX;
// Record the global together with its lower bound in the buffer description.
159 CB.Constants.emplace_back(std::make_pair(GV, LowerBound));
// Dispatches on the dynamic kind of `it` and feeds variable declarations into
// the buffer description CB.
//
// Visible branches:
//   - VarDecl: forwarded to addConstant(ConstDecl, CB).
//   - CXXRecordDecl / EmptyDecl: separate branch, body not shown here.
//   - FunctionDecl: separate branch, body not shown here.
// NOTE(review): the statement that introduces `it` is on a listing line absent
// from this excerpt; presumably it iterates the declarations of DC - confirm.
162void CGHLSLRuntime::addBufferDecls(
const DeclContext *DC, Buffer &CB) {
164 if (
auto *ConstDecl = dyn_cast<VarDecl>(it)) {
165 addConstant(ConstDecl, CB);
166 }
else if (isa<CXXRecordDecl, EmptyDecl>(it)) {
168 }
else if (isa<FunctionDecl>(it)) {
177 Buffers.emplace_back(
Buffer(
D));
178 addBufferDecls(
D, Buffers.back());
184 Triple
T(M.getTargetTriple());
185 if (
T.getArch() == Triple::ArchType::dxil)
186 addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
190 addDisableOptimizations(M);
192 const DataLayout &DL = M.getDataLayout();
194 for (
auto &Buf : Buffers) {
195 layoutBuffer(Buf, DL);
196 GlobalVariable *GV = replaceBuffer(Buf);
197 M.insertGlobalVariable(GV);
198 llvm::hlsl::ResourceClass RC = Buf.
IsCBuffer
199 ? llvm::hlsl::ResourceClass::CBuffer
200 : llvm::hlsl::ResourceClass::SRV;
201 llvm::hlsl::ResourceKind RK = Buf.
IsCBuffer
202 ? llvm::hlsl::ResourceKind::CBuffer
203 : llvm::hlsl::ResourceKind::TBuffer;
204 addBufferResourceAnnotation(GV, RC, RK,
false,
205 llvm::hlsl::ElementType::Invalid, Buf.
Binding);
210 : Name(
D->
getName()), IsCBuffer(
D->isCBuffer()),
211 Binding(
D->getAttr<HLSLResourceBindingAttr>()) {}
// Appends a FrontendResource metadata record describing GV to the named
// metadata list that corresponds to its resource class:
//   UAV     -> "hlsl.uavs"
//   SRV     -> "hlsl.srvs"
//   CBuffer -> "hlsl.cbufs"
//
// GV      - global variable the record is built for.
// RC      - resource class; selects the named metadata list.
// RK, ET  - resource kind and element type, passed through to the record.
// Binding - supplies the register (optional) and space stored in the record.
// NOTE(review): `IsROV` is used below but its parameter line (listing line
// 216), the `switch` header and the break/default lines are not part of this
// excerpt.
213void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
214 llvm::hlsl::ResourceClass RC,
215 llvm::hlsl::ResourceKind RK,
217 llvm::hlsl::ElementType ET,
218 BufferResBinding &Binding) {
221 NamedMDNode *ResourceMD =
nullptr;
// Select (creating on first use) the named metadata list for this class.
223 case llvm::hlsl::ResourceClass::UAV:
224 ResourceMD = M.getOrInsertNamedMetadata(
"hlsl.uavs");
226 case llvm::hlsl::ResourceClass::SRV:
227 ResourceMD = M.getOrInsertNamedMetadata(
"hlsl.srvs");
229 case llvm::hlsl::ResourceClass::CBuffer:
230 ResourceMD = M.getOrInsertNamedMetadata(
"hlsl.cbufs");
// Unconditional assertion failure; the label that leads here is on an elided
// line (presumably the fall-back for other resource classes - confirm).
233 assert(
false &&
"Unsupported buffer type!");
236 assert(ResourceMD !=
nullptr &&
237 "ResourceMD must have been set by the switch above.");
// A binding without a register is encoded as UINT_MAX.
239 llvm::hlsl::FrontendResource Res(
240 GV, RK, ET, IsROV, Binding.Reg.value_or(
UINT_MAX), Binding.Space);
241 ResourceMD->addOperand(Res.getMetadata());
// Derives an llvm::hlsl::ElementType from the first template argument of a
// resource type.
//
// NOTE(review): the declarator line of this function and every condition that
// guards the returns below are on listing lines absent from this excerpt; only
// the outcomes are visible here. Confirm the exact type tests in the full file.
244static llvm::hlsl::ElementType
246 using llvm::hlsl::ElementType;
251 assert(TST &&
"Resource types must be template specializations");
253 assert(!Args.empty() &&
"Resource has no element type");
// The element type is read from template argument 0.
257 QualType ElTy = Args[0].getAsType();
// Continue with the element type of VecTy (its definition and guard are on
// elided lines).
261 ElTy = VecTy->getElementType();
// Possible results: I16/I32/I64, U16/U32/U64 and F16/F32/F64. Each return is
// selected by a check that is not shown in this excerpt.
266 return ElementType::I16;
268 return ElementType::I32;
270 return ElementType::I64;
275 return ElementType::U16;
277 return ElementType::U32;
279 return ElementType::U64;
282 return ElementType::F16;
284 return ElementType::F32;
286 return ElementType::F64;
// Reaching this point is treated as unreachable.
289 llvm_unreachable(
"Invalid element type for resource");
293 const Type *Ty =
D->getType()->getPointeeOrArrayElementType();
301 for (
auto *FD : RD->fields()) {
302 const auto *HLSLResAttr = FD->
getAttr<HLSLResourceAttr>();
304 dyn_cast<HLSLAttributedResourceType>(FD->getType().getTypePtr());
305 if (!HLSLResAttr || !AttrResType)
308 llvm::hlsl::ResourceClass RC = AttrResType->getAttrs().ResourceClass;
309 if (RC == llvm::hlsl::ResourceClass::UAV ||
310 RC == llvm::hlsl::ResourceClass::SRV)
319 bool IsROV = AttrResType->getAttrs().IsROV;
320 llvm::hlsl::ResourceKind RK = HLSLResAttr->getResourceKind();
323 BufferResBinding Binding(
D->
getAttr<HLSLResourceBindingAttr>());
324 addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding);
// Initialises Reg and Space from the textual slot and space of a resource
// binding attribute.
//
// Slot:  the first character is dropped and the remainder parsed as base 10.
// Space: the first five characters are dropped (presumably a "space" prefix -
//        confirm) and the remainder parsed as base 10.
// The return value of getAsInteger() is discarded in both cases, so a failed
// parse is not detected here.
// NOTE(review): listing line 330 is absent from this excerpt; any guard on
// `Binding` being null would live there - confirm before assuming one exists.
328CGHLSLRuntime::BufferResBinding::BufferResBinding(
329 HLSLResourceBindingAttr *Binding) {
331 llvm::APInt RegInt(64, 0);
332 Binding->getSlot().substr(1).getAsInteger(10, RegInt);
// getLimitedValue() narrows the 64-bit APInt to the value stored in Reg.
333 Reg = RegInt.getLimitedValue();
334 llvm::APInt SpaceInt(64, 0);
335 Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);
336 Space = SpaceInt.getLimitedValue();
344 const auto *ShaderAttr = FD->
getAttr<HLSLShaderAttr>();
345 assert(ShaderAttr &&
"All entry functions must have a HLSLShaderAttr");
346 const StringRef ShaderAttrKindStr =
"hlsl.shader";
347 Fn->addFnAttr(ShaderAttrKindStr,
348 llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));
349 if (HLSLNumThreadsAttr *NumThreadsAttr = FD->
getAttr<HLSLNumThreadsAttr>()) {
350 const StringRef NumThreadsKindStr =
"hlsl.numthreads";
351 std::string NumThreadsStr =
352 formatv(
"{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
353 NumThreadsAttr->getZ());
354 Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
356 if (HLSLWaveSizeAttr *WaveSizeAttr = FD->
getAttr<HLSLWaveSizeAttr>()) {
357 const StringRef WaveSizeKindStr =
"hlsl.wavesize";
358 std::string WaveSizeStr =
359 formatv(
"{0},{1},{2}", WaveSizeAttr->getMin(), WaveSizeAttr->getMax(),
360 WaveSizeAttr->getPreferred());
361 Fn->addFnAttr(WaveSizeKindStr, WaveSizeStr);
363 Fn->addFnAttr(llvm::Attribute::NoInline);
367 if (
const auto *VT = dyn_cast<FixedVectorType>(Ty)) {
369 for (
unsigned I = 0; I < VT->getNumElements(); ++I) {
370 Value *Elt = B.CreateCall(F, {B.getInt32(I)});
375 return B.CreateCall(F, {B.getInt32(0)});
381 assert(
D.
hasAttrs() &&
"Entry parameter missing annotation attribute!");
382 if (
D.
hasAttr<HLSLSV_GroupIndexAttr>()) {
383 llvm::Function *DxGroupIndex =
385 return B.CreateCall(FunctionCallee(DxGroupIndex));
387 if (
D.
hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
388 llvm::Function *ThreadIDIntrinsic =
392 if (
D.
hasAttr<HLSLSV_GroupThreadIDAttr>()) {
393 llvm::Function *GroupThreadIDIntrinsic =
397 if (
D.
hasAttr<HLSLSV_GroupIDAttr>()) {
398 llvm::Function *GroupIDIntrinsic =
CGM.
getIntrinsic(getGroupIdIntrinsic());
401 assert(
false &&
"Unhandled parameter attribute");
406 llvm::Function *Fn) {
408 llvm::LLVMContext &Ctx = M.getContext();
409 auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx),
false);
411 Function::Create(EntryTy, Function::ExternalLinkage, FD->
getName(), &M);
415 AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
416 Fn->getAttributes().getFnAttrs());
417 EntryFn->setAttributes(NewAttrs);
421 Fn->setLinkage(GlobalValue::InternalLinkage);
423 BasicBlock *BB = BasicBlock::Create(Ctx,
"entry", EntryFn);
429 assert(EntryFn->isConvergent());
430 llvm::Value *I = B.CreateIntrinsic(
431 llvm::Intrinsic::experimental_convergence_entry, {}, {});
432 llvm::Value *bundleArgs[] = {I};
433 OB.emplace_back(
"convergencectrl", bundleArgs);
438 unsigned SRetOffset = 0;
439 for (
const auto &Param : Fn->args()) {
440 if (Param.hasStructRetAttr()) {
444 Args.emplace_back(PoisonValue::get(Param.getType()));
451 CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);
452 CI->setCallingConv(Fn->getCallingConv());
459 llvm::Function *Fn) {
461 const StringRef ExportAttrKindStr =
"hlsl.export";
462 Fn->addFnAttr(ExportAttrKindStr);
469 M.getNamedGlobal(CtorOrDtor ?
"llvm.global_ctors" :
"llvm.global_dtors");
472 const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer());
480 for (
const auto &Ctor : CA->operands()) {
481 if (isa<ConstantAggregateZero>(Ctor))
483 ConstantStruct *CS = cast<ConstantStruct>(Ctor);
485 assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
486 "HLSL doesn't support setting priority for global ctors.");
487 assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
488 "HLSL doesn't support COMDat for global ctors.");
489 Fns.push_back(cast<Function>(CS->getOperand(1)));
503 for (
auto &F : M.functions()) {
504 if (!F.hasFnAttribute(
"hlsl.shader"))
507 Instruction *IP = &*F.getEntryBlock().begin();
510 llvm::Value *bundleArgs[] = {
Token};
511 OB.emplace_back(
"convergencectrl", bundleArgs);
512 IP =
Token->getNextNode();
515 for (
auto *Fn : CtorFns) {
516 auto CI = B.CreateCall(FunctionCallee(Fn), {}, OB);
517 CI->setCallingConv(Fn->getCallingConv());
521 B.SetInsertPoint(F.back().getTerminator());
522 for (
auto *Fn : DtorFns) {
523 auto CI = B.CreateCall(FunctionCallee(Fn), {}, OB);
524 CI->setCallingConv(Fn->getCallingConv());
530 Triple
T(M.getTargetTriple());
531 if (
T.getEnvironment() != Triple::EnvironmentType::Library) {
532 if (
auto *GV = M.getNamedGlobal(
"llvm.global_ctors"))
533 GV->eraseFromParent();
534 if (
auto *GV = M.getNamedGlobal(
"llvm.global_dtors"))
535 GV->eraseFromParent();
545 llvm::GlobalVariable *GV,
unsigned Slot,
548 llvm::Type *Int1Ty = llvm::Type::getInt1Ty(Ctx);
550 llvm::Function *InitResFunc = llvm::Function::Create(
551 llvm::FunctionType::get(
CGM.
VoidTy,
false),
552 llvm::GlobalValue::InternalLinkage,
554 InitResFunc->addFnAttr(llvm::Attribute::AlwaysInline);
556 llvm::BasicBlock *EntryBB =
557 llvm::BasicBlock::Create(Ctx,
"entry", InitResFunc);
560 Builder.SetInsertPoint(EntryBB);
569 assert(AttrResType !=
nullptr &&
570 "Resource class must have a handle of HLSLAttributedResourceType");
572 llvm::Type *TargetTy =
574 assert(TargetTy !=
nullptr &&
575 "Failed to convert resource handle to target type");
577 llvm::Value *Args[] = {
578 llvm::ConstantInt::get(
CGM.
IntTy, Space),
579 llvm::ConstantInt::get(
CGM.
IntTy, Slot),
581 llvm::ConstantInt::get(
CGM.
IntTy, 1),
582 llvm::ConstantInt::get(
CGM.
IntTy, 0),
584 llvm::ConstantInt::get(Int1Ty,
false)
586 llvm::Value *CreateHandle = Builder.CreateIntrinsic(
589 Twine(VD->
getName()).concat(
"_h"));
591 llvm::Value *HandleRef = Builder.CreateStructGEP(GV->getValueType(), GV, 0);
592 Builder.CreateAlignedStore(CreateHandle, HandleRef,
593 HandleRef->getPointerAlignment(DL));
594 Builder.CreateRetVoid();
600 llvm::GlobalVariable *GV) {
604 const HLSLResourceBindingAttr *RBA = VD->
getAttr<HLSLResourceBindingAttr>();
617 RBA->getSpaceNumber());
625 for (
auto I = BB.begin(); I !=
E; ++I) {
626 auto *II = dyn_cast<llvm::IntrinsicInst>(&*I);
627 if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID())) {
631 llvm_unreachable(
"Convergence token should have been emitted.");
static llvm::hlsl::ElementType calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy)
static void gatherFunctions(SmallVectorImpl< Function * > &Fns, llvm::Module &M, bool CtorOrDtor)
static Value * buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty)
static void createResourceInitFn(CodeGenModule &CGM, const VarDecl *VD, llvm::GlobalVariable *GV, unsigned Slot, unsigned Space)
static bool isResourceRecordType(const clang::Type *Ty)
static std::string getName(const CallEvent &Call)
Defines the clang::TargetOptions class.
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
This class gathers all debug information during compilation and is responsible for emitting to llvm g...
llvm::Instruction * getConvergenceToken(llvm::BasicBlock &BB)
void setHLSLEntryAttributes(const FunctionDecl *FD, llvm::Function *Fn)
void setHLSLFunctionAttributes(const FunctionDecl *FD, llvm::Function *Fn)
void emitEntryFunction(const FunctionDecl *FD, llvm::Function *Fn)
void handleGlobalVarDefinition(const VarDecl *VD, llvm::GlobalVariable *Var)
llvm::Type * convertHLSLSpecificType(const Type *T)
llvm::Value * emitInputSemantic(llvm::IRBuilder<> &B, const ParmVarDecl &D, llvm::Type *Ty)
void annotateHLSLResource(const VarDecl *D, llvm::GlobalVariable *GV)
void addBuffer(const HLSLBufferDecl *D)
void generateGlobalCtorDtorCalls()
This class organizes the cross-function state that is used while generating LLVM code.
CGHLSLRuntime & getHLSLRuntime()
Return a reference to the configured HLSL runtime.
llvm::Module & getModule() const
CGDebugInfo * getModuleDebugInfo()
void AddCXXGlobalInit(llvm::Function *F)
const TargetInfo & getTarget() const
void EmitGlobal(GlobalDecl D)
Emit code for a single global function or var decl.
bool shouldEmitConvergenceTokens() const
ASTContext & getContext() const
llvm::Constant * GetAddrOfGlobalVar(const VarDecl *D, llvm::Type *Ty=nullptr, ForDefinition_t IsForDefinition=NotForDefinition)
Return the llvm::Constant for the address of the given global variable.
const TargetCodeGenInfo & getTargetCodeGenInfo()
const CodeGenOptions & getCodeGenOpts() const
llvm::LLVMContext & getLLVMContext()
llvm::Function * getIntrinsic(unsigned IID, ArrayRef< llvm::Type * > Tys={})
void EmitTopLevelDecl(Decl *D)
Emit code for a single top level declaration.
virtual llvm::Type * getHLSLType(CodeGenModule &CGM, const Type *T) const
Return an LLVM type that corresponds to a HLSL type.
DeclContext - This is used only as base class of specific decl types that can act as declaration cont...
decl_range decls() const
decls_begin/decls_end - Iterate over the declarations stored in this context.
Decl - This represents one declaration (or definition), e.g.
bool isInExportDeclContext() const
Whether this declaration was exported in a lexical context.
Represents a function declaration or definition.
const ParmVarDecl * getParamDecl(unsigned i) const
static const HLSLAttributedResourceType * findHandleTypeOnResource(const Type *RT)
HLSLBufferDecl - Represent a cbuffer or tbuffer declaration.
StringRef getName() const
Get the name of identifier for this declaration as a StringRef.
Represents a parameter to a function.
A (possibly-)qualified type.
const Type * getTypePtr() const
Retrieves a pointer to the underlying (unqualified) type.
TargetOptions & getTargetOpts() const
Retrieve the target options.
const llvm::Triple & getTriple() const
Returns the target triple of the primary target.
Represents a type template specialization; the template must be a class template, a type alias templa...
Token - This structure provides full information about a lexed token.
The base class of the type hierarchy.
CXXRecordDecl * getAsCXXRecordDecl() const
Retrieves the CXXRecordDecl that this type refers to, either because the type is a RecordType or beca...
bool isSignedIntegerType() const
Return true if this is an integer type that is signed, according to C99 6.2.5p4 [char,...
bool isSpecificBuiltinType(unsigned K) const
Test for a particular builtin type.
bool isHLSLSpecificType() const
bool isUnsignedIntegerType() const
Return true if this is an integer type that is unsigned, according to C99 6.2.5p6 [which returns true...
const T * getAs() const
Member-template getAs<specific type>'.
Represents a variable declaration or definition.
Represents a GCC generic vector type.
bool Const(InterpState &S, CodePtr OpPC, const T &Arg)
The JSON file list parser is used to communicate input to InstallAPI.
@ Result
The result type of a method or function.
const FunctionProtoType * T
Diagnostic wrappers for TextAPI types for error reporting.
std::vector< std::pair< llvm::GlobalVariable *, unsigned > > Constants
llvm::StructType * LayoutStruct
Buffer(const HLSLBufferDecl *D)
llvm::IntegerType * IntTy
int