Files
MY_COMPILER/midend/llvm/instr/CallInstr.java
2025-12-12 20:14:00 +08:00

118 lines
4.3 KiB
Java
Executable File

package midend.llvm.instr;
import java.util.ArrayList;
import backend.mips.Register;
import backend.mips.instr.MipsAlu;
import backend.mips.instr.MipsJump;
import backend.mips.instr.MipsLs;
import backend.mips.instr.type.MipsAluType;
import backend.mips.instr.type.MipsJumpType;
import backend.mips.instr.type.MipsLsType;
import backend.mips.MipsBuilder;
import midend.llvm.value.IrFuncValue;
import midend.llvm.value.IrValue;
public class CallInstr extends IrInstr {
public CallInstr(String name, IrFuncValue func, ArrayList<IrValue> args) {
super(func.getRetType(), name, IrInstrType.CALL);
addUse(func);
for (IrValue arg : args) {
addUse(arg);
}
}
public boolean callVoid() {
return getType().isVoid();
}
public IrFuncValue getCalledFunc() {
return (IrFuncValue) getUse(0);
}
public ArrayList<IrValue> getArgs() {
ArrayList<IrValue> args = new ArrayList<>();
for (int i = 1; i < getNumUses(); i++) {
args.add(getUse(i));
}
return args;
}
public String toString() {
StringBuilder sb = new StringBuilder();
if (!callVoid()) {
sb.append(getName() + " = ");
}
sb.append("call " + getType() + " " + getCalledFunc().getName() + "(");
for (int i = 1; i < getNumUses(); i++) {
sb.append(getUse(i).getType() + " " + getUse(i).getName());
if (i < getNumUses() - 1) {
sb.append(", ");
}
}
sb.append(")");
return sb.toString();
}
public void toMips() {
ArrayList<Register> usedRegisters = MipsBuilder.getUsedRegisters();
int offset = MipsBuilder.getOffset();
save(usedRegisters, offset);
ArrayList<IrValue> args = getArgs();
saveArgs(args, offset, usedRegisters);
offset -= (usedRegisters.size() + 2) * 4;
new MipsAlu(MipsAluType.ADDI, Register.SP, Register.SP, offset);
new MipsJump(MipsJumpType.JAL, getCalledFunc().getMipsLabel());
offset += (usedRegisters.size() + 2) * 4;
recover(usedRegisters, offset);
saveResult(this, Register.V0);
}
public void save(ArrayList<Register> usedRegisters, int offset) {
int num = 0;
for (Register reg : usedRegisters) {
num++;
new MipsLs(MipsLsType.SW, reg, Register.SP, offset - num * 4);
}
new MipsLs(MipsLsType.SW, Register.SP, Register.SP, offset - (num + 1) * 4);
new MipsLs(MipsLsType.SW, Register.RA, Register.SP, offset - (num + 2) * 4);
}
public void saveArgs(ArrayList<IrValue> args, int offset, ArrayList<Register> usedRegisters) {
int num = 0;
ArrayList<Register> argRegs = new ArrayList<>();
argRegs.add(Register.A1);
argRegs.add(Register.A2);
argRegs.add(Register.A3);
for (IrValue arg : args) {
num++;
if (num <= 3) { // 分配到A1、A2、A3
if (argRegs.contains(MipsBuilder.getRegister(arg))) {
int index = usedRegisters.indexOf(MipsBuilder.getRegister(arg));
new MipsLs(MipsLsType.LW, argRegs.get(num - 1), Register.SP, offset - (index + 1) * 4);
} else {
loadValueToReg(arg, argRegs.get(num - 1));
}
} else {
if (MipsBuilder.getRegister(arg) == Register.K0 || argRegs.contains(MipsBuilder.getRegister(arg))) {
int index = usedRegisters.indexOf(MipsBuilder.getRegister(arg));
new MipsLs(MipsLsType.LW, Register.K0, Register.SP, offset - (index + 1) * 4);
} else {
loadValueToReg(arg, Register.K0);
}
new MipsLs(MipsLsType.SW, Register.K0, Register.SP, offset - (usedRegisters.size() + num + 2) * 4);
}
}
}
public void recover(ArrayList<Register> usedRegisters, int offset) {
new MipsLs(MipsLsType.LW, Register.RA, Register.SP, 0);
new MipsLs(MipsLsType.LW, Register.SP, Register.SP, 4);
int num = 0;
for (Register reg : usedRegisters) {
num++;
new MipsLs(MipsLsType.LW, reg, Register.SP, offset - num * 4);
}
}
}