diff --git a/Compiler.java b/Compiler.java index 6e9e055..35d949d 100755 --- a/Compiler.java +++ b/Compiler.java @@ -7,9 +7,8 @@ import frontend.lexer.TokenStream; import frontend.parser.Parser; import midend.Midend; import midend.errorhandle.ErrorHandler; - +import midend.optimize.Optimize; import error.Errors; -import midend.symbol.SymbolManager; public class Compiler { public static void main(String[] args) { @@ -17,6 +16,7 @@ public class Compiler { try { String content = new String(Files.readAllBytes(Paths.get("testfile.txt"))); String llvmFile = "llvm_ir.txt"; + String llvmOpFile = "llvm_op_ir.txt"; String mipsFile = "mips.txt"; String errorFile = "error.txt"; Lexer lexer = new Lexer(content); @@ -36,10 +36,13 @@ public class Compiler { } else { Midend midend = new Midend(parser.getCompUnit()); midend.generateLLvmIr(); - // midend.writeToFile(llvmFile); - BackEnd backEnd = new BackEnd(midend.getModule()); - backEnd.toMips(); - backEnd.writeToFile(mipsFile); + midend.writeToFile(llvmFile); + Optimize optimize = new Optimize(midend.getModule()); + optimize.run(); + midend.writeToFile(llvmOpFile); + // BackEnd backEnd = new BackEnd(midend.getModule()); + // backEnd.toMips(); + // backEnd.writeToFile(mipsFile); } } catch (Exception e) { e.printStackTrace(); diff --git a/midend/llvm/instr/PhiInstr.java b/midend/llvm/instr/PhiInstr.java new file mode 100644 index 0000000..195242a --- /dev/null +++ b/midend/llvm/instr/PhiInstr.java @@ -0,0 +1,57 @@ +package midend.llvm.instr; + +import java.util.ArrayList; +import java.util.StringJoiner; + +import midend.llvm.type.IrType; +import midend.llvm.value.IrBasicBlock; +import midend.llvm.value.IrValue; + +public class PhiInstr extends IrInstr { + private ArrayList predBBs; + + public PhiInstr(IrType type, IrBasicBlock bblock, String name) { + super(type, name, IrInstrType.PHI); + this.setBBlock(bblock); + this.predBBs = new ArrayList<>(bblock.getPreds()); + for (IrBasicBlock bb : predBBs) { + this.addUse(null); + } + } + + public ArrayList getPredBBs() { + return predBBs; + } + + public void setValueForPred(IrBasicBlock pred, IrValue value) { + int index = predBBs.indexOf(pred); + this.setUse(index, value); + } + + public void deletePred(IrBasicBlock pred) { + int index = predBBs.indexOf(pred); + if (index != -1) { + predBBs.remove(index); + this.deleteUse(index); + } + } + + public void replacePred(IrBasicBlock oldPred, IrBasicBlock newPred) { + //TODO: whether the function is needed + } + + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getName()); + sb.append(" = phi "); + sb.append(getType().toString()); + for (int i = 0; i < this.predBBs.size(); i++) { + sb.append("[ "); + sb.append(getUse(i).getName()); + sb.append(", %"); + sb.append(this.predBBs.get(i).getName()); + sb.append(" ]"); + } + return sb.toString(); + } +} diff --git a/midend/llvm/use/IrUser.java b/midend/llvm/use/IrUser.java index c699e5b..15a7b10 100755 --- a/midend/llvm/use/IrUser.java +++ b/midend/llvm/use/IrUser.java @@ -18,11 +18,10 @@ public class IrUser extends IrValue { } public void addUse(IrValue value) { - if (value == null) { - return; - } uses.add(value); - value.addUser(this); + if (value != null) { + value.addUser(this); + } } public IrValue getUse(int index) { @@ -32,6 +31,44 @@ public class IrUser extends IrValue { return uses.get(index); } + public void setUse(int index, IrValue value) { + if (index >= uses.size() || index < 0) { + return; + } + IrValue oldValue = uses.get(index); + if (oldValue != null) { + oldValue.deleteUser(this); + } + uses.set(index, value); + if (value != null) { + value.addUser(this); + } + } + + public void deleteUse(IrValue value) { + uses.remove(value); + if (value != null) { + value.deleteUser(this); + } + } + + public void clearUses() { + for (IrValue value : uses) { + if (value != null) { + value.deleteUser(this); + } + } + uses.clear(); + } + + public void deleteUse(int index) { + if (index >= uses.size() || index < 0) { + return; + } + IrValue value = uses.get(index); + deleteUse(value); + } + public int getNumUses() { return uses.size(); } diff --git a/midend/llvm/value/IrBasicBlock.java b/midend/llvm/value/IrBasicBlock.java index 17ca7b0..df57b6c 100755 --- a/midend/llvm/value/IrBasicBlock.java +++ b/midend/llvm/value/IrBasicBlock.java @@ -35,6 +35,27 @@ public class IrBasicBlock extends IrValue { instr.setBBlock(this); } + public void addInstr(IrInstr instr, int index) { + instrs.add(index, instr); + instr.setBBlock(this); + } + + public void deleteInstr(IrInstr instr) { + instrs.remove(instr); + instr.setBBlock(null); + instr.clearUses(); + instr.clearUsers(); + } + + public void clearAllInstrs() { + for (IrInstr instr : instrs) { + instr.setBBlock(null); + instr.clearUses(); + instr.clearUsers(); + } + instrs.clear(); + } + public IrFuncValue getFunc() { return func; } @@ -87,6 +108,20 @@ public class IrBasicBlock extends IrValue { return instrs; } + public IrInstr getFirstInstr() { + if (instrs.isEmpty()) { + return null; + } + return instrs.get(0); + } + + public IrInstr getLastInstr() { + if (instrs.isEmpty()) { + return null; + } + return instrs.get(instrs.size() - 1); + } + public void addPred(IrBasicBlock bb) { this.preds.add(bb); } @@ -119,6 +154,14 @@ public class IrBasicBlock extends IrValue { return this.directDomi; } + public HashSet getDirectDomies() { + return directDomies; + } + + public HashSet getDomiFrontier() { + return domiFrontier; + } + public void toMips() { new MipsLabel(getMipsLabel()); for (IrInstr instr : instrs) { diff --git a/midend/llvm/value/IrFuncValue.java b/midend/llvm/value/IrFuncValue.java index 60a60d6..610bca4 100755 --- a/midend/llvm/value/IrFuncValue.java +++ b/midend/llvm/value/IrFuncValue.java @@ -45,6 +45,27 @@ public class IrFuncValue extends IrValue { bblocks.add(bblock); } + public void deleteBBlock(IrBasicBlock bblock) { + bblocks.remove(bblock); + bblock.clearAllInstrs(); + } + + public void deleteDeadBlock() { + ArrayList liveBlocks = new ArrayList<>(); + ArrayList deadBlocks = new ArrayList<>(); + for (IrBasicBlock bb : bblocks) { + if (!bb.getPreds().isEmpty() || bb.isEntry()) { + liveBlocks.add(bb); + } else { + deadBlocks.add(bb); + } + } + bblocks = liveBlocks; + for (IrBasicBlock bb : deadBlocks) { + bb.clearAllInstrs(); + } + } + public Register getRegister(IrValue value) { return valueRegisterMap.get(value); } diff --git a/midend/llvm/value/IrValue.java b/midend/llvm/value/IrValue.java index f365032..dabc551 100755 --- a/midend/llvm/value/IrValue.java +++ b/midend/llvm/value/IrValue.java @@ -32,6 +32,28 @@ public class IrValue { users.add(user); } + public void deleteUser(IrUser user) { + users.remove(user); + } + + public void clearUsers() { + ArrayList usersCopy = new ArrayList<>(users); + for (IrUser user : usersCopy) { + user.deleteUse(this); + } + } + + public void replaceUserToAnother(IrValue newValue) { + ArrayList usersCopy = new ArrayList<>(users); + for (IrUser user : usersCopy) { + for (int i = 0; i < user.getUses().size(); i++) { + if (user.getUse(i) == this) { + user.setUse(i, newValue); + } + } + } + } + public String toString() { return type.toString() + " " + name; } diff --git a/midend/optimize/CfgMake.java b/midend/optimize/CfgMake.java index 6ab6f2a..ef43d03 100755 --- a/midend/optimize/CfgMake.java +++ b/midend/optimize/CfgMake.java @@ -1,6 +1,7 @@ package midend.optimize; import java.util.HashSet; +import java.util.ArrayList; import midend.llvm.instr.IrInstr; import midend.llvm.instr.IrInstrType; @@ -16,7 +17,81 @@ public class CfgMake extends Optimizer { bb.clearCfg(); } } + deleteUselessInstrs(); + deleteUnreachedBlocks(); makeCfg(); + for (IrFuncValue func : getIrModule().getFuncs()) { + func.deleteDeadBlock(); + } + makeDomination(); + makeDirectDomi(); + makeDomiFrontier(); + } + + public void deleteUselessInstrs() { + for (IrFuncValue func : getIrModule().getFuncs()) { + for (IrBasicBlock bb : func.getBBlocks()) { + ArrayList toDelete = new ArrayList<>(); + boolean deleteUseless = false; + for (IrInstr instr : bb.getInstrs()) { + if (!deleteUseless) { + if (instr.getInstrType() == IrInstrType.RET || + instr.getInstrType() == IrInstrType.BR || + instr.getInstrType() == IrInstrType.JUMP) { + deleteUseless = true; + } + } else { + toDelete.add(instr); + } + } + for (IrInstr instr : toDelete) { + bb.deleteInstr(instr); + } + } + } + } + + public void deleteUnreachedBlocks() { + for (IrFuncValue func : getIrModule().getFuncs()) { + HashSet canArrive = new HashSet<>(); + IrBasicBlock entryBlock = func.getBBlock(0); + dfsArrive(entryBlock, canArrive); + ArrayList toDelete = new ArrayList<>(); + for (IrBasicBlock bb : func.getBBlocks()) { + if (!canArrive.contains(bb)) { + toDelete.add(bb); + } + } + for (IrBasicBlock bb : toDelete) { + func.deleteBBlock(bb); + } + } + } + + public void dfsArrive(IrBasicBlock bb, HashSet canArrive) { + if (!canArrive.contains(bb)) { + canArrive.add(bb); + } + IrInstr lastInstr = bb.getLastInstr(); + if (lastInstr != null) { + if (lastInstr.getInstrType() == IrInstrType.BR) { + BranchInstr branchInstr = (BranchInstr) lastInstr; + IrBasicBlock trueBlock = branchInstr.getTrueBB(); + IrBasicBlock falseBlock = branchInstr.getFalseBB(); + if (!canArrive.contains(trueBlock)) { + dfsArrive(trueBlock, canArrive); + } + if (!canArrive.contains(falseBlock)) { + dfsArrive(falseBlock, canArrive); + } + } else if (lastInstr.getInstrType() == IrInstrType.JUMP) { + JumpInstr jumpInstr = (JumpInstr) lastInstr; + IrBasicBlock jumpBlock = jumpInstr.getTargetBlock(); + if (!canArrive.contains(jumpBlock)) { + dfsArrive(jumpBlock, canArrive); + } + } + } } public void makeCfg() { @@ -76,8 +151,8 @@ public class CfgMake extends Optimizer { for (IrFuncValue func : getIrModule().getFuncs()) { for (IrBasicBlock bb : func.getBBlocks()) { for (IrBasicBlock domi : bb.getDomied()) { - HashSet bbDomi = bb.getDomied(); - HashSet domiDomi = domi.getDomied(); + HashSet bbDomi = new HashSet<>(bb.getDomied()); + HashSet domiDomi = new HashSet<>(domi.getDomied()); for (IrBasicBlock domiDomiBB : domiDomi) { if (bbDomi.contains(domiDomiBB)) { bbDomi.remove(domiDomiBB); diff --git a/midend/optimize/DeleteDead.java b/midend/optimize/DeleteDead.java new file mode 100644 index 0000000..9c800cd --- /dev/null +++ b/midend/optimize/DeleteDead.java @@ -0,0 +1,20 @@ +package midend.optimize; + +import java.util.HashSet; + +import midend.llvm.value.IrFuncValue; + +public class DeleteDead extends Optimizer { + public void optimize() { + + } + + public void deleteDeadFunc() { + IrFuncValue mainFunc = getIrModule().getMainFunc(); + HashSet liveFuncs = new HashSet<>(); + liveFuncs.add(mainFunc); + findLiveFuncs(mainFunc, liveFuncs); + } + + public void findLiveFuncs() +} diff --git a/midend/optimize/MemToReg.java b/midend/optimize/MemToReg.java new file mode 100644 index 0000000..ac81123 --- /dev/null +++ b/midend/optimize/MemToReg.java @@ -0,0 +1,36 @@ +package midend.optimize; + +import midend.llvm.value.IrBasicBlock; +import midend.llvm.value.IrFuncValue; + +import java.util.ArrayList; + +import midend.llvm.instr.AllocateInstr; +import midend.llvm.instr.IrInstr; +import midend.llvm.type.IrArrayType; + +public class MemToReg extends Optimizer { + public void optimize() { + for (IrFuncValue func : getIrModule().getFuncs()) { + IrBasicBlock entryBlock = func.getBBlock(0); + for (IrBasicBlock block : func.getBBlocks()) { + ArrayList instrs = new ArrayList<>(block.getInstrs()); + for (IrInstr instr : instrs) { + if (normalAlloca(instr)) { + AllocateInstr allocInstr = (AllocateInstr) instr; + PhiInsert phiInsert = new PhiInsert(allocInstr, entryBlock); + phiInsert.run(); + } + } + } + } + } + + public boolean normalAlloca(IrInstr instr) { + if (!(instr instanceof AllocateInstr)) { + return false; + } + AllocateInstr allocInstr = (AllocateInstr) instr; + return !(allocInstr.getPointeeType() instanceof IrArrayType); + } +} diff --git a/midend/optimize/Optimize.java b/midend/optimize/Optimize.java new file mode 100644 index 0000000..5bfe2e5 --- /dev/null +++ b/midend/optimize/Optimize.java @@ -0,0 +1,23 @@ +package midend.optimize; + +import java.util.ArrayList; + +import midend.llvm.IrModule; + +public class Optimize { + private ArrayList optimizers; + + public Optimize(IrModule module) { + Optimizer.setIrModule(module); + optimizers = new ArrayList<>(); + optimizers.add(new CfgMake()); + optimizers.add(new MemToReg()); + optimizers.add(new CfgMake()); + } + + public void run() { + for (Optimizer optimizer : optimizers) { + optimizer.optimize(); + } + } +} diff --git a/midend/optimize/Optimizer.java b/midend/optimize/Optimizer.java index abc317f..9a65ba8 100755 --- a/midend/optimize/Optimizer.java +++ b/midend/optimize/Optimizer.java @@ -5,7 +5,7 @@ import midend.llvm.IrModule; public class Optimizer { private static IrModule irModule; - public void setIrModule(IrModule irModule) { + public static void setIrModule(IrModule irModule) { Optimizer.irModule = irModule; } diff --git a/midend/optimize/PhiInsert.java b/midend/optimize/PhiInsert.java new file mode 100644 index 0000000..27bdac6 --- /dev/null +++ b/midend/optimize/PhiInsert.java @@ -0,0 +1,133 @@ +package midend.optimize; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Stack; + +import midend.llvm.IrBuilder; +import midend.llvm.constant.IrConstantInt; +import midend.llvm.instr.AllocateInstr; +import midend.llvm.instr.IrInstr; +import midend.llvm.instr.StoreInstr; +import midend.llvm.instr.LoadInstr; +import midend.llvm.instr.PhiInstr; +import midend.llvm.use.IrUser; +import midend.llvm.value.IrBasicBlock; +import midend.llvm.value.IrValue; + +public class PhiInsert { + private ArrayList defineInstrs; + private ArrayList useInstrs; + private AllocateInstr allocInstr; + private IrBasicBlock entryBlock; + private ArrayList defBlocks; + private ArrayList useBlocks; + private Stack workList; + + public PhiInsert(AllocateInstr allocInstr, IrBasicBlock entryBlock) { + this.defineInstrs = new ArrayList<>(); + this.useInstrs = new ArrayList<>(); + this.allocInstr = allocInstr; + this.entryBlock = entryBlock; + this.defBlocks = new ArrayList<>(); + this.useBlocks = new ArrayList<>(); + this.workList = new Stack<>(); + } + + public IrValue getNewValue() { + if (workList.isEmpty()) { + return new IrConstantInt(0); + } + return workList.peek(); + } + + public void run() { + makeDefAndUse(); + insertPhiNodes(); + changeLoadAndStoreToReg(entryBlock); + } + + public void makeDefAndUse() { + ArrayList users = new ArrayList<>(allocInstr.getUsers()); + for (IrUser user : users) { + IrInstr instr = (IrInstr) user; + if (instr instanceof StoreInstr) { + defineInstrs.add(instr); + IrBasicBlock defBlock = instr.getBBlock(); + if (!defBlocks.contains(defBlock)) { + defBlocks.add(defBlock); + } + } else if (instr instanceof LoadInstr){ + useInstrs.add(instr); + IrBasicBlock useBlock = instr.getBBlock(); + if (!useBlocks.contains(useBlock)) { + useBlocks.add(useBlock); + } + } + } + } + + public void insertPhiNodes() { + HashSet hasPhiBlocks = new HashSet<>(); + Stack defStack = new Stack<>(); + for (IrBasicBlock defBlock : defBlocks) { + defStack.push(defBlock); + } + while (!defStack.isEmpty()) { + IrBasicBlock currBlock = defStack.pop(); + for (IrBasicBlock frontierBlock : currBlock.getDomiFrontier()) { + if (hasPhiBlocks.contains(frontierBlock)) { + continue; + } + //插入phi节点 + insertPhiNode(frontierBlock); + hasPhiBlocks.add(frontierBlock); + if (!defBlocks.contains(frontierBlock)) { + defStack.push(frontierBlock); + } + } + } + } + + public void insertPhiNode(IrBasicBlock block) { + PhiInstr phiInstr = new PhiInstr(allocInstr.getPointeeType(), block, IrBuilder.getLocalName(block.getFunc())); + block.addInstr(phiInstr, 0); + defineInstrs.add(phiInstr); + useInstrs.add(phiInstr); + } + + public void changeLoadAndStoreToReg(IrBasicBlock block) { + Stack workListCopy = new Stack<>(); + workListCopy.addAll(workList); + ArrayList deleteInstrs = new ArrayList<>(); + for(IrInstr instr : block.getInstrs()) { + if (instr instanceof LoadInstr && useInstrs.contains(instr)) { + instr.replaceUserToAnother(getNewValue()); + deleteInstrs.add(instr); + } else if (instr instanceof StoreInstr && defineInstrs.contains(instr)) { + StoreInstr storeInstr = (StoreInstr) instr; + IrValue value = storeInstr.getValue(); + workList.push(value); + deleteInstrs.add(instr); + } else if (instr instanceof PhiInstr && defineInstrs.contains(instr)) { + workList.push(instr); + } else if (instr == allocInstr) { + deleteInstrs.add(instr); + } + } + for (IrInstr instr : deleteInstrs) { + block.deleteInstr(instr); + } + for (IrBasicBlock succ : block.getSuccs()) { + IrInstr firstInstr = succ.getFirstInstr(); + if (firstInstr instanceof PhiInstr && useInstrs.contains(firstInstr)) { + PhiInstr phiInstr = (PhiInstr) firstInstr; + phiInstr.setValueForPred(block, getNewValue()); + } + } + for (IrBasicBlock domiBB : block.getDirectDomies()) { + changeLoadAndStoreToReg(domiBB); + } + workList = workListCopy; + } +}