Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PTX backend implementation for HAT #188

Closed
wants to merge 11 commits into from
29 changes: 11 additions & 18 deletions hat/backends/ptx/src/main/java/hat/backend/PTXBackend.java
Original file line number Diff line number Diff line change
@@ -77,8 +77,6 @@ public void dispatchKernel(KernelCallGraph kernelCallGraph, NDRange ndRange, Obj

System.out.println("Entrypoint ->"+kernelCallGraph.entrypoint.method.getName());
String code = createCode(kernelCallGraph, new PTXCodeBuilder(), args);
// System.out.println("\nCode Builder Output: \n\n" + code);
// System.out.println("Add your code to "+PTXBackend.class.getName()+".dispatchKernel() to actually run! :)");
long programHandle = compileProgram(code);
if (programOK(programHandle)) {
long kernelHandle = getKernel(programHandle, kernelCallGraph.entrypoint.method.getName());
@@ -88,9 +86,8 @@ public void dispatchKernel(KernelCallGraph kernelCallGraph, NDRange ndRange, Obj
}

public String createCode(KernelCallGraph kernelCallGraph, PTXCodeBuilder builder, Object[] args) {
String out = "";
Optional<CoreOp.FuncOp> o = Optional.ofNullable(kernelCallGraph.entrypoint.funcOpWrapper().op());
FuncOpWrapper f = new FuncOpWrapper(o.orElseThrow());
StringBuilder out = new StringBuilder();
FuncOpWrapper f = new FuncOpWrapper(kernelCallGraph.entrypoint.funcOpWrapper().op());
FuncOpWrapper lowered = f.lower();
HashMap<String, Object> argsMap = new HashMap<>();
for (int i = 0; i < args.length; i++) {
@@ -102,25 +99,20 @@ public String createCode(KernelCallGraph kernelCallGraph, PTXCodeBuilder builder

// printing out ptx header (device info)
builder.ptxHeader(major, minor, target, addressSize);
out += builder.getTextAndReset();
out.append(builder.getTextAndReset());

for (KernelCallGraph.KernelReachableResolvedMethodCall k : kernelCallGraph.kernelReachableResolvedStream().toList()) {
Optional<CoreOp.FuncOp> optional = Optional.ofNullable(k.funcOpWrapper().op());
FuncOpWrapper calledFunc = new FuncOpWrapper(optional.orElseThrow());
FuncOpWrapper calledFunc = new FuncOpWrapper(k.funcOpWrapper().op());
FuncOpWrapper loweredFunc = calledFunc.lower();
// System.out.println("------------func------------");
// System.out.println(loweredFunc.ssa().toText());
if (useSchema) loweredFunc = transformPtrs(loweredFunc, argsMap);
out += createFunction(new PTXCodeBuilder(addressSize).nl().nl(), loweredFunc, loweredFunc.ssa(), out, false);
out.append(createFunction(new PTXCodeBuilder(addressSize).nl().nl(), loweredFunc, false));
}

if (useSchema) lowered = transformPtrs(lowered, argsMap);
FuncOpWrapper ssa = lowered.ssa();
// System.out.println(lowered.toText());
// System.out.println(ssa.toText());
out += createFunction(builder.nl().nl(), lowered, ssa, out, true);

return out;
out.append(createFunction(builder.nl().nl(), lowered, true));

return out.toString();
}

public FuncOpWrapper transformPtrs(FuncOpWrapper func, HashMap<String, Object> argsMap) {
@@ -148,8 +140,9 @@ public FuncOpWrapper transformPtrs(FuncOpWrapper func, HashMap<String, Object> a
}));
}

public String createFunction(PTXCodeBuilder builder, FuncOpWrapper lowered, FuncOpWrapper ssa, String out, boolean entry) {
String body = "";
public String createFunction(PTXCodeBuilder builder, FuncOpWrapper lowered, boolean entry) {
FuncOpWrapper ssa = lowered.ssa();
String out, body;

// building fn info (name, params)
builder.functionHeader(lowered.functionName(), entry, lowered.op().body().yieldType());
Loading