diff --git a/liquidjava-example/src/main/java/testSuite/classes/imagewrite_correct/ImageWriteParamsRefinements.java b/liquidjava-example/src/main/java/testSuite/classes/imagewrite_correct/ImageWriteParamsRefinements.java new file mode 100644 index 00000000..32491265 --- /dev/null +++ b/liquidjava-example/src/main/java/testSuite/classes/imagewrite_correct/ImageWriteParamsRefinements.java @@ -0,0 +1,88 @@ +package testSuite.classes.imagewrite_correct; + +import java.util.Locale; + +import javax.imageio.ImageWriteParam; + +import liquidjava.specification.ExternalRefinementsFor; +import liquidjava.specification.Refinement; +import liquidjava.specification.StateRefinement; +import liquidjava.specification.StateSet; + +/** + * External typestate specification for {@code javax.imageio.ImageWriteParam}. + * + *

+ * The class is modelled as two independent ghost-state dimensions — tiling and compression — so a configuration error + * in one dimension does not mask the other. The conditional {@code setTilingMode} / {@code setCompressionMode} + * transitions only reach the {@code *Explicit} state when called with {@code MODE_EXPLICIT}; any other mode leaves the + * param in its {@code start*} state, which the dimension-specific setters reject. + */ +@StateSet({ "startTiling", "tilingExplicit", "tilingSet" }) +@StateSet({ "startCompression", "compressionExplicit", "compressionSet" }) +@ExternalRefinementsFor("javax.imageio.ImageWriteParam") +public interface ImageWriteParamsRefinements { + + // Constructor + @StateRefinement(to = "startTiling(this) && startCompression(this)") + void ImageWriteParam(Locale locale); + + // Tiling related methods + + @StateRefinement(to = "(mode == ImageWriteParam.MODE_EXPLICIT)? tilingExplicit(this) : startTiling(this)") + void setTilingMode(int mode); + + @StateRefinement(from = "tilingExplicit(this)", to = "tilingSet(this)") + @StateRefinement(from = "tilingSet(this)", to = "tilingSet(this)") + void setTiling(@Refinement("_ > 0") int tileWidth, @Refinement("_ > 0") int tileHeight, int tileGridXOffset, + int tileGridYOffset); + + @StateRefinement(from = "tilingSet(this)") + int getTileGridXOffset(); + + @StateRefinement(from = "tilingSet(this)") + int getTileGridYOffset(); + + @StateRefinement(from = "tilingSet(this)") + int getTileHeight(); + + @StateRefinement(from = "tilingSet(this)") + int getTileWidth(); + + @StateRefinement(from = "tilingExplicit(this)") + @StateRefinement(from = "tilingSet(this)", to = "tilingExplicit(this)") + void unsetTiling(); + + void setProgressiveMode(@Refinement("ImageWriteParam.MODE_DISABLED == mode || mode == ImageWriteParam.MODE_DEFAULT || mode == ImageWriteParam.MODE_COPY_FROM_METADATA") int mode); + + // Compression related methods + + @StateRefinement(to = "mode == ImageWriteParam.MODE_EXPLICIT? compressionExplicit(this) : startCompression(this)") + void setCompressionMode(int mode); + + @StateRefinement(from = "compressionExplicit(this)") + @StateRefinement(from = "compressionSet(this)") + void setCompressionQuality(@Refinement("_ >= 0.0 && _ <= 1.0") float quality); + + @StateRefinement(from = "compressionExplicit(this)") + @StateRefinement(from = "compressionSet(this)") + String getCompressionType(); + + @StateRefinement(from = "compressionExplicit(this)", to = "compressionSet(this)") + void setCompressionType(String compressionType); + + @StateRefinement(from = "compressionExplicit(this)") + @StateRefinement(from = "compressionSet(this)", to = "compressionExplicit(this)") + void unsetCompression(); + + @StateRefinement(from = "compressionSet(this)") + String getLocalizedCompressionTypeName(); + + @StateRefinement(from = "compressionExplicit(this)") + @StateRefinement(from = "compressionSet(this)") + boolean isCompressionLossless(); + + @StateRefinement(from = "compressionExplicit(this)") + @StateRefinement(from = "compressionSet(this)") + float getCompressionQuality(); +} diff --git a/liquidjava-example/src/main/java/testSuite/classes/imagewrite_correct/JpegExporter.java b/liquidjava-example/src/main/java/testSuite/classes/imagewrite_correct/JpegExporter.java new file mode 100644 index 00000000..e7d6b657 --- /dev/null +++ b/liquidjava-example/src/main/java/testSuite/classes/imagewrite_correct/JpegExporter.java @@ -0,0 +1,32 @@ +package testSuite.classes.imagewrite_correct; + +import java.util.Locale; + +import javax.imageio.ImageWriteParam; + +/** + * A JPEG export pipeline configured correctly against {@link ImageWriteParamsRefinements}. + * + *

+ * Both ghost-state dimensions are driven through their full transition path: each dimension's mode is set to + * {@code MODE_EXPLICIT} before the dimension-specific setters run, and {@code getTileWidth} is reached only after + * {@code setTiling} has moved the param into {@code tilingSet}. No state refinement is violated. + */ +class JpegExporter { + + ImageWriteParam buildJpegParam() { + ImageWriteParam param = new ImageWriteParam(Locale.getDefault()); + param.setTilingMode(ImageWriteParam.MODE_EXPLICIT); + param.setTiling(10, 30, 10, 30); + param.setCompressionMode(ImageWriteParam.MODE_EXPLICIT); + param.setCompressionQuality(0.85f); + return param; + } + + int firstTileWidth() { + ImageWriteParam param = new ImageWriteParam(Locale.getDefault()); + param.setTilingMode(ImageWriteParam.MODE_EXPLICIT); + param.setTiling(8, 8, 0, 0); + return param.getTileWidth(); + } +} diff --git a/liquidjava-example/src/main/java/testSuite/classes/imagewrite_error/ImageWriteParamsRefinements.java b/liquidjava-example/src/main/java/testSuite/classes/imagewrite_error/ImageWriteParamsRefinements.java new file mode 100644 index 00000000..23f3055b --- /dev/null +++ b/liquidjava-example/src/main/java/testSuite/classes/imagewrite_error/ImageWriteParamsRefinements.java @@ -0,0 +1,88 @@ +package testSuite.classes.imagewrite_error; + +import java.util.Locale; + +import javax.imageio.ImageWriteParam; + +import liquidjava.specification.ExternalRefinementsFor; +import liquidjava.specification.Refinement; +import liquidjava.specification.StateRefinement; +import liquidjava.specification.StateSet; + +/** + * External typestate specification for {@code javax.imageio.ImageWriteParam}. + * + *

+ * The class is modelled as two independent ghost-state dimensions — tiling and compression — so a configuration error + * in one dimension does not mask the other. The conditional {@code setTilingMode} / {@code setCompressionMode} + * transitions only reach the {@code *Explicit} state when called with {@code MODE_EXPLICIT}; any other mode leaves the + * param in its {@code start*} state, which the dimension-specific setters reject. + */ +@StateSet({ "startTiling", "tilingExplicit", "tilingSet" }) +@StateSet({ "startCompression", "compressionExplicit", "compressionSet" }) +@ExternalRefinementsFor("javax.imageio.ImageWriteParam") +public interface ImageWriteParamsRefinements { + + // Constructor + @StateRefinement(to = "startTiling(this) && startCompression(this)") + void ImageWriteParam(Locale locale); + + // Tiling related methods + + @StateRefinement(to = "(mode == ImageWriteParam.MODE_EXPLICIT)? tilingExplicit(this) : startTiling(this)") + void setTilingMode(int mode); + + @StateRefinement(from = "tilingExplicit(this)", to = "tilingSet(this)") + @StateRefinement(from = "tilingSet(this)", to = "tilingSet(this)") + void setTiling(@Refinement("_ > 0") int tileWidth, @Refinement("_ > 0") int tileHeight, int tileGridXOffset, + int tileGridYOffset); + + @StateRefinement(from = "tilingSet(this)") + int getTileGridXOffset(); + + @StateRefinement(from = "tilingSet(this)") + int getTileGridYOffset(); + + @StateRefinement(from = "tilingSet(this)") + int getTileHeight(); + + @StateRefinement(from = "tilingSet(this)") + int getTileWidth(); + + @StateRefinement(from = "tilingExplicit(this)") + @StateRefinement(from = "tilingSet(this)", to = "tilingExplicit(this)") + void unsetTiling(); + + void setProgressiveMode(@Refinement("ImageWriteParam.MODE_DISABLED == mode || mode == ImageWriteParam.MODE_DEFAULT || mode == ImageWriteParam.MODE_COPY_FROM_METADATA") int mode); + + // Compression related methods + + @StateRefinement(to = "mode == ImageWriteParam.MODE_EXPLICIT? compressionExplicit(this) : startCompression(this)") + void setCompressionMode(int mode); + + @StateRefinement(from = "compressionExplicit(this)") + @StateRefinement(from = "compressionSet(this)") + void setCompressionQuality(@Refinement("_ >= 0.0 && _ <= 1.0") float quality); + + @StateRefinement(from = "compressionExplicit(this)") + @StateRefinement(from = "compressionSet(this)") + String getCompressionType(); + + @StateRefinement(from = "compressionExplicit(this)", to = "compressionSet(this)") + void setCompressionType(String compressionType); + + @StateRefinement(from = "compressionExplicit(this)") + @StateRefinement(from = "compressionSet(this)", to = "compressionExplicit(this)") + void unsetCompression(); + + @StateRefinement(from = "compressionSet(this)") + String getLocalizedCompressionTypeName(); + + @StateRefinement(from = "compressionExplicit(this)") + @StateRefinement(from = "compressionSet(this)") + boolean isCompressionLossless(); + + @StateRefinement(from = "compressionExplicit(this)") + @StateRefinement(from = "compressionSet(this)") + float getCompressionQuality(); +} diff --git a/liquidjava-example/src/main/java/testSuite/classes/imagewrite_error/JpegExporter.java b/liquidjava-example/src/main/java/testSuite/classes/imagewrite_error/JpegExporter.java new file mode 100644 index 00000000..c5dae5e3 --- /dev/null +++ b/liquidjava-example/src/main/java/testSuite/classes/imagewrite_error/JpegExporter.java @@ -0,0 +1,29 @@ +package testSuite.classes.imagewrite_error; + +import java.util.Locale; + +import javax.imageio.ImageWriteParam; + +/** + * A JPEG export pipeline configured against {@link ImageWriteParamsRefinements}. + * + *

+ * The author did configure a tiling mode — but passed {@code MODE_DEFAULT} instead of {@code MODE_EXPLICIT}. The spec's + * conditional transition leaves the param in {@code startTiling} for any non-explicit mode, so {@code setTiling} + * (which requires {@code tilingExplicit} or {@code tilingSet}) violates its from-state. + * + *

+ * The found-state threads the same {@code param} across SSA versions joined by internal {@code stateN(x) == stateN(y)} + * equalities; state derivation rewrites those into developer-facing typestate names for the diagnostic. + */ +class JpegExporter { + + ImageWriteParam buildJpegParam() { + ImageWriteParam param = new ImageWriteParam(Locale.getDefault()); + param.setTilingMode(ImageWriteParam.MODE_DEFAULT); + param.setCompressionMode(ImageWriteParam.MODE_EXPLICIT); + param.setCompressionQuality(0.85f); + param.setTiling(10, 30, 10, 30); // State Refinement Error + return param; + } +} diff --git a/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java b/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java index d01b03e5..f3f23460 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java +++ b/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java @@ -391,6 +391,9 @@ protected void throwStateRefinementError(SourcePosition position, Predicate foun gatherVariables(found, lrv, mainVars); TranslationTable map = new TranslationTable(); VCImplication foundState = joinPredicates(found, mainVars, lrv, map); + // simplify(context) folds the found-state predicate and, in the same pass, rewrites internal + // ghost-state equalities into developer-facing state predicates (see ExpressionSimplifier / + // StateDerivation). The resulting ValDerivationNode keeps the provenance of that rewrite. throw new StateRefinementError(position, expected.simplify(context), foundState.toConjunctions().simplify(context), map, customMessage); } diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java index 64da35b2..095dcded 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java @@ -257,8 +257,9 @@ public ValDerivationNode simplify(Context context) { for (AliasWrapper aw : context.getAliases()) { aliases.put(aw.getName(), aw.createAliasDTO()); } - // simplify expression - ValDerivationNode result = ExpressionSimplifier.simplify(exp.clone(), aliases); + // simplify expression — ghost states let the simplifier rewrite internal state equalities into + // developer-facing state predicates for error messages + ValDerivationNode result = ExpressionSimplifier.simplify(exp.clone(), aliases, context.getGhostStates()); DebugLog.simplification(this.getExpression(), result.getValue()); return result; } diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java index e94d2574..be4da893 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java @@ -2,7 +2,9 @@ import liquidjava.diagnostics.DebugLog; import liquidjava.processor.context.Context; +import liquidjava.processor.context.GhostState; import liquidjava.rj_language.Predicate; +import java.util.List; import java.util.Map; import liquidjava.processor.facade.AliasDTO; @@ -23,9 +25,10 @@ public class ExpressionSimplifier { * Simplifies an expression by applying constant propagation, constant folding, removing redundant conjuncts and * expanding aliases Returns a derivation node representing the tree of simplifications applied */ - public static ValDerivationNode simplify(Expression exp, Map aliases) { + public static ValDerivationNode simplify(Expression exp, Map aliases, + List ghostStates) { DebugLog.simplificationStart(exp); - ValDerivationNode fixedPoint = simplifyToFixedPoint(null, exp); + ValDerivationNode fixedPoint = simplifyToFixedPoint(null, exp, ghostStates); DebugLog.simplificationPass("fixed-point reached", fixedPoint.getValue()); ValDerivationNode simplified = simplifyValDerivationNode(fixedPoint); DebugLog.simplificationPass("remove redundant &&", simplified.getValue()); @@ -36,16 +39,21 @@ public static ValDerivationNode simplify(Expression exp, Map a return expanded; } + public static ValDerivationNode simplify(Expression exp, Map aliases) { + return simplify(exp, aliases, List.of()); + } + public static ValDerivationNode simplify(Expression exp) { - return simplify(exp, Map.of()); + return simplify(exp, Map.of(), List.of()); } /** * Recursively applies propagation and folding until the expression stops changing (fixed point) Stops early if the * expression simplifies to a boolean literal, which means we've simplified too much. */ - private static ValDerivationNode simplifyToFixedPoint(ValDerivationNode current, Expression prevExp) { - ValDerivationNode simplified = simplifyOnce(current, prevExp); + private static ValDerivationNode simplifyToFixedPoint(ValDerivationNode current, Expression prevExp, + List ghostStates) { + ValDerivationNode simplified = simplifyOnce(current, prevExp, ghostStates); Expression currExp = simplified.getValue(); // fixed point reached — compare on toString() because propagate/fold/reduce mutate the AST in place, so a @@ -60,15 +68,25 @@ private static ValDerivationNode simplifyToFixedPoint(ValDerivationNode current, } // continue simplifying - return simplifyToFixedPoint(simplified, simplified.getValue()); + return simplifyToFixedPoint(simplified, simplified.getValue(), ghostStates); } - private static ValDerivationNode simplifyOnce(ValDerivationNode current, Expression prevExp) { - ValDerivationNode prop = VariablePropagation.propagate(prevExp, current); + private static ValDerivationNode simplifyOnce(ValDerivationNode current, Expression prevExp, + List ghostStates) { + // Propagation is told not to collapse ghost-state equalities (stateN(x) -> stateN(y)): otherwise a + // chain state1(a)==state1(b) && state1(b)==state1(c) would lose its shared middle term before + // derivation could consume both links. + ValDerivationNode prop = VariablePropagation.propagate(prevExp, current, + StateDerivation.internalStateFunctions(ghostStates)); DebugLog.simplificationPass("variable propagation", prop.getValue()); ValDerivationNode fold = ExpressionFolding.fold(prop); DebugLog.simplificationPass("expression folding", fold.getValue()); - ValDerivationNode simplified = simplifyValDerivationNode(fold); + // Derivation runs after folding: state conjuncts start life as unresolved ?: ternaries, and folding + // turns them into the concrete FunctionInvocation facts derivation matches. The surrounding + // fixed-point loop re-runs this pass, so equality chains resolve link by link across iterations. + ValDerivationNode derived = StateDerivation.derive(fold, ghostStates); + DebugLog.simplificationPass("state derivation", derived.getValue()); + ValDerivationNode simplified = simplifyValDerivationNode(derived); DebugLog.simplificationPass("remove redundant && (loop)", simplified.getValue()); return simplified; } diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/StateDerivation.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/StateDerivation.java new file mode 100644 index 00000000..098c584a --- /dev/null +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/StateDerivation.java @@ -0,0 +1,212 @@ +package liquidjava.rj_language.opt; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import liquidjava.processor.context.GhostState; +import liquidjava.rj_language.ast.BinaryExpression; +import liquidjava.rj_language.ast.Expression; +import liquidjava.rj_language.ast.FunctionInvocation; +import liquidjava.rj_language.ast.GroupExpression; +import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode; +import liquidjava.rj_language.opt.derivation_node.StateDerivationNode; +import liquidjava.rj_language.opt.derivation_node.ValDerivationNode; +import liquidjava.utils.Utils; +import liquidjava.utils.constants.Ops; + +/** + * Simplification pass that rewrites internal ghost-state equalities into developer-facing state predicates. + * + *

+ * Typestate found-state conjunctions mix developer state invocations ({@code startTiling(p)}) with internal ghost-state + * equalities ({@code state1(a) == state1(b)}). The equalities are an SMT artifact, meaningless to the developer. When a + * known developer state holds for one side of such an equality, the same state must hold for the other side; this pass + * derives that state predicate and drops the equality. + * + *

+ * It is a single pass: it only consults the conjuncts it is given. Chain resolution + * ({@code state1(a)==state1(b) && state1(b)==state1(c)}) falls out of the surrounding {@code simplifyToFixedPoint} loop + * re-running the pass until stable. + * + *

+ * Derivation only matches plain {@link FunctionInvocation} conjuncts and plain {@code stateN(x) == stateN(y)} + * equalities. It deliberately does not traverse into {@code Ite} / disjunction / negation — a conditional branch + * is not an asserted fact, so deriving from it would be unsound. + */ +public class StateDerivation { + + /** + * Rewrites ghost-state equalities in {@code node} into derived developer-state conjuncts. Returns {@code node} + * unchanged when nothing can be derived, so the enclosing fixed-point loop terminates. + */ + public static ValDerivationNode derive(ValDerivationNode node, List ghostStates) { + if (ghostStates == null || ghostStates.isEmpty()) + return node; + Map stateToInternal = buildStateMap(ghostStates); + if (stateToInternal.isEmpty()) + return node; + + List conjuncts = flatten(node); + List kept = new ArrayList<>(); + List derived = new ArrayList<>(); + boolean changed = false; + + for (ValDerivationNode conjunct : conjuncts) { + FunctionInvocation[] equality = stateEquality(conjunct.getValue()); + if (equality != null) { + List fromEquality = deriveFromEquality(equality[0], equality[1], conjunct, conjuncts, + stateToInternal); + if (!fromEquality.isEmpty()) { + derived.addAll(fromEquality); + changed = true; + continue; // drop the equality — it has been rewritten as developer states + } + } + kept.add(conjunct); + } + + if (!changed) + return node; + + List result = new ArrayList<>(kept); + for (ValDerivationNode d : derived) { + if (result.stream().noneMatch(c -> sameValue(c, d))) + result.add(d); + } + return rebuildConjunction(result, node); + } + + /** + * The set of internal {@code stateN} function names backing the given developer states. Used to tell variable + * propagation which ghost-state equalities it must leave intact for derivation to consume. + */ + public static Set internalStateFunctions(List ghostStates) { + Set names = new HashSet<>(); + if (ghostStates == null) + return names; + for (String internal : buildStateMap(ghostStates).values()) + names.add(Utils.getSimpleName(internal)); + return names; + } + + /** Maps each developer state name to the internal {@code stateN} function backing its refinement. */ + private static Map buildStateMap(List ghostStates) { + Map map = new HashMap<>(); + for (GhostState gs : ghostStates) { + if (gs.getRefinement() == null) + continue; + Expression ref = unwrap(gs.getRefinement().getExpression()); + if (ref instanceof BinaryExpression be && Ops.EQ.equals(be.getOperator())) { + FunctionInvocation state = functionInvocation(be.getFirstOperand()); + if (state == null) + state = functionInvocation(be.getSecondOperand()); + if (state != null) + map.put(gs.getName(), state.getName()); + } + } + return map; + } + + /** Returns {@code {left, right}} when {@code exp} is {@code stateN(x) == stateN(y)}; otherwise {@code null}. */ + private static FunctionInvocation[] stateEquality(Expression exp) { + Expression e = unwrap(exp); + if (!(e instanceof BinaryExpression be) || !Ops.EQ.equals(be.getOperator())) + return null; + FunctionInvocation left = functionInvocation(be.getFirstOperand()); + FunctionInvocation right = functionInvocation(be.getSecondOperand()); + if (left == null || right == null || left.getArgs().size() != 1 || right.getArgs().size() != 1) + return null; + if (!sameName(left.getName(), right.getName())) + return null; + return new FunctionInvocation[] { left, right }; + } + + /** + * For an equality {@code stateN(x) == stateN(y)}, finds every known developer state holding for one operand and + * derives the same state for the other operand. Each derived node carries a {@link StateDerivationNode} recording + * the equality and the known state it came from. + */ + private static List deriveFromEquality(FunctionInvocation left, FunctionInvocation right, + ValDerivationNode equalityNode, List conjuncts, Map stateToInternal) { + List derived = new ArrayList<>(); + String internal = left.getName(); + for (ValDerivationNode conjunct : conjuncts) { + FunctionInvocation state = functionInvocation(conjunct.getValue()); + if (state == null || state.getArgs().size() != 1) + continue; + for (Map.Entry entry : stateToInternal.entrySet()) { + if (!sameName(entry.getKey(), state.getName()) || !sameName(entry.getValue(), internal)) + continue; + Expression known = state.getArgs().get(0); + Expression target = known.equals(left.getArgs().get(0)) ? right.getArgs().get(0) + : known.equals(right.getArgs().get(0)) ? left.getArgs().get(0) : null; + if (target != null) + derived.add(new ValDerivationNode(invocation(state.getName(), target), + new StateDerivationNode(equalityNode, conjunct))); + } + } + return derived; + } + + private static FunctionInvocation invocation(String name, Expression arg) { + List args = new ArrayList<>(); + args.add(arg.clone()); + return new FunctionInvocation(name, args); + } + + /** Splits a left-associated {@code &&} derivation tree into its conjunct nodes, preserving their origins. */ + private static List flatten(ValDerivationNode node) { + List out = new ArrayList<>(); + flattenInto(node, out); + return out; + } + + private static void flattenInto(ValDerivationNode node, List out) { + Expression value = node.getValue(); + if (value instanceof BinaryExpression be && Ops.AND.equals(be.getOperator())) { + if (node.getOrigin()instanceof BinaryDerivationNode bin) { + flattenInto(bin.getLeft(), out); + flattenInto(bin.getRight(), out); + } else { + flattenInto(new ValDerivationNode(be.getFirstOperand(), null), out); + flattenInto(new ValDerivationNode(be.getSecondOperand(), null), out); + } + } else { + out.add(node); + } + } + + /** Rebuilds a left-associated {@code &&} conjunction from {@code conjuncts}. */ + private static ValDerivationNode rebuildConjunction(List conjuncts, ValDerivationNode fallback) { + if (conjuncts.isEmpty()) + return fallback; + ValDerivationNode acc = conjuncts.get(0); + for (int i = 1; i < conjuncts.size(); i++) { + ValDerivationNode next = conjuncts.get(i); + Expression value = new BinaryExpression(acc.getValue(), Ops.AND, next.getValue()); + acc = new ValDerivationNode(value, new BinaryDerivationNode(acc, next, Ops.AND)); + } + return acc; + } + + private static FunctionInvocation functionInvocation(Expression exp) { + Expression e = unwrap(exp); + return e instanceof FunctionInvocation fi ? fi : null; + } + + private static Expression unwrap(Expression exp) { + return exp instanceof GroupExpression ge ? unwrap(ge.getExpression()) : exp; + } + + private static boolean sameName(String first, String second) { + return first.equals(second) || Utils.getSimpleName(first).equals(Utils.getSimpleName(second)); + } + + private static boolean sameValue(ValDerivationNode a, ValDerivationNode b) { + return a.getValue().toString().equals(b.getValue().toString()); + } +} diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java index 48b37e03..dd66080c 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java @@ -15,6 +15,9 @@ import java.util.HashMap; import java.util.Map; +import java.util.Set; + +import liquidjava.utils.Utils; public class VariablePropagation { @@ -24,7 +27,27 @@ public class VariablePropagation { * steps taken. */ public static ValDerivationNode propagate(Expression exp, ValDerivationNode previousOrigin) { + return propagate(exp, previousOrigin, Set.of()); + } + + /** + * Variant of {@link #propagate(Expression, ValDerivationNode)} that leaves ghost-state equalities intact. + * + *

+ * Substitutions of the form {@code stateN(x) -> stateN(y)} (both sides invocations of a ghost-state function named + * in {@code protectedStateFunctions}) are dropped before propagation. Collapsing them would erase the shared middle + * term of an equality chain ({@code state1(a)==state1(b) && state1(b)==state1(c)}), which state-equality derivation + * needs whole. These equalities are an SMT artifact consumed only by error-message derivation. + */ + public static ValDerivationNode propagate(Expression exp, ValDerivationNode previousOrigin, + Set protectedStateFunctions) { Map substitutions = VariableResolver.resolve(exp); + if (!protectedStateFunctions.isEmpty()) { + substitutions.entrySet() + .removeIf(e -> isProtectedStateInvocation(e.getKey(), protectedStateFunctions) + && e.getValue()instanceof FunctionInvocation fi + && protectedStateFunctions.contains(Utils.getSimpleName(fi.getName()))); + } Map directSubstitutions = new HashMap<>(); // var == literal or var == var Map expressionSubstitutions = new HashMap<>(); // var == expression for (Map.Entry entry : substitutions.entrySet()) { @@ -46,6 +69,12 @@ public static ValDerivationNode propagate(Expression exp, ValDerivationNode prev return propagateRecursive(exp, activeSubstitutions, varOrigins); } + /** A substitution key {@code f(...)} whose function {@code f} is one of the protected ghost-state functions. */ + private static boolean isProtectedStateInvocation(String key, Set stateFunctions) { + int paren = key.indexOf('('); + return paren > 0 && stateFunctions.contains(Utils.getSimpleName(key.substring(0, paren))); + } + /** * Recursively performs propagation on an expression (e.g. x + y && x == 1 && y == 2 => 1 + 2) */ diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/StateDerivationNode.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/StateDerivationNode.java new file mode 100644 index 00000000..fd4190a6 --- /dev/null +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/StateDerivationNode.java @@ -0,0 +1,32 @@ +package liquidjava.rj_language.opt.derivation_node; + +/** + * Origin of a developer-facing state predicate that was derived from a ghost-state equality. + * + *

+ * When a found-state conjunction holds a known developer state for one object ({@code knownState(a)}) together with an + * internal ghost-state equality ({@code state1(a) == state1(b)}), the same state must hold for the other object. The + * derivation pass rewrites this into {@code knownState(b)} and attaches a {@code StateDerivationNode} so the step is + * explainable: {@code knownState(b)} was derived through {@link #getSourceEquality()} starting from + * {@link #getSourceState()}. + */ +public class StateDerivationNode extends DerivationNode { + + private final ValDerivationNode sourceEquality; + private final ValDerivationNode sourceState; + + public StateDerivationNode(ValDerivationNode sourceEquality, ValDerivationNode sourceState) { + this.sourceEquality = sourceEquality; + this.sourceState = sourceState; + } + + /** The ghost-state equality ({@code stateN(x) == stateN(y)}) the derivation went through. */ + public ValDerivationNode getSourceEquality() { + return sourceEquality; + } + + /** The known developer state the derivation started from. */ + public ValDerivationNode getSourceState() { + return sourceState; + } +} diff --git a/liquidjava-verifier/src/test/java/liquidjava/api/tests/DerivedStateErrorMessageTest.java b/liquidjava-verifier/src/test/java/liquidjava/api/tests/DerivedStateErrorMessageTest.java new file mode 100644 index 00000000..7a321caf --- /dev/null +++ b/liquidjava-verifier/src/test/java/liquidjava/api/tests/DerivedStateErrorMessageTest.java @@ -0,0 +1,48 @@ +package liquidjava.api.tests; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; + +import liquidjava.api.CommandLineLauncher; +import liquidjava.diagnostics.Diagnostics; +import liquidjava.diagnostics.errors.StateRefinementError; + +/** + * End-to-end check that {@link StateRefinementError} diagnostics present the found-state conjunction with + * developer-facing typestate names — not the internal {@code stateN(x) == stateN(y)} equalities the SMT layer threads + * across call boundaries. + * + *

+ * {@code TestExamples} only matches an error's title and line, never its message body, so the state-equality derivation + * rewrite needs a dedicated assertion on the rendered text. The {@code imagewrite_error} scenario is a two-dimension + * external-typestate spec whose found-state relates the same {@code param} across several SSA versions — exactly the + * shape that produces those equalities. + */ +class DerivedStateErrorMessageTest { + + private static final String IMAGEWRITE_ERROR = "../liquidjava-example/src/main/java/testSuite/classes/imagewrite_error"; + + @Test + void stateRefinementErrorShowsDeveloperStatesNotInternalEqualities() { + CommandLineLauncher.launch(IMAGEWRITE_ERROR); + + StateRefinementError error = Diagnostics.getInstance().getErrors().stream() + .filter(StateRefinementError.class::isInstance).map(StateRefinementError.class::cast).findFirst() + .orElseThrow(() -> new AssertionError("expected a StateRefinementError from imagewrite_error")); + + // the rendered, developer-visible message string ("Expected state ... but found ...") + String message = error.getMessage(); + + // derivation rewrote every cross-version ghost-state equality into a named typestate + assertTrue(message.contains("startTiling("), "message should name the tiling typestate, got: " + message); + assertTrue(message.contains("compressionExplicit("), + "message should name the compression typestate, got: " + message); + + // the internal ghost-state functions and their equalities must not leak into the diagnostic + assertFalse(message.contains("state1("), "internal state1(...) leaked into the message: " + message); + assertFalse(message.contains("state2("), "internal state2(...) leaked into the message: " + message); + assertFalse(message.contains("=="), "a raw ghost-state equality leaked into the message: " + message); + } +} diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/DerivedStateEqualitiesTest.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/DerivedStateEqualitiesTest.java new file mode 100644 index 00000000..b27f2f47 --- /dev/null +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/DerivedStateEqualitiesTest.java @@ -0,0 +1,183 @@ +package liquidjava.rj_language.opt; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import liquidjava.processor.context.GhostState; +import liquidjava.processor.facade.AliasDTO; +import liquidjava.rj_language.Predicate; +import liquidjava.rj_language.ast.Expression; +import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode; +import liquidjava.rj_language.opt.derivation_node.StateDerivationNode; +import liquidjava.rj_language.opt.derivation_node.ValDerivationNode; +import liquidjava.rj_language.parsing.RefinementsParser; + +/** + * Test suite for state-equality derivation inside the expression simplifier. + * + * Derivation rewrites internal ghost-state equalities (e.g. {@code state1(a) == state1(b)}) into developer-facing state + * predicates (e.g. {@code knownState(b)}) so {@code StateRefinementError} diagnostics show meaningful typestate names. + */ +class DerivedStateEqualitiesTest { + + private static Expression parse(String sut) { + return RefinementsParser.createAST(sut, ""); + } + + /** + * Builds a developer state named {@code stateName} whose refinement is backed by the internal ghost-state function + * {@code internalName} (the {@code stateN} dimension). + */ + private static GhostState ghostState(String stateName, String internalName) { + GhostState gs = new GhostState(stateName, null, null, "", ""); + gs.setRefinement(new Predicate(parse(internalName + "(wild) == 0"))); + return gs; + } + + @Test + void testBasicDerivationDerivesKnownStateForOtherOperand() { + Expression expression = parse("state1(a) == state1(b) && knownState(a)"); + List ghostStates = List.of(ghostState("knownState", "state1")); + + ValDerivationNode result = ExpressionSimplifier.simplify(expression, Map.of(), ghostStates); + + assertEquals("knownState(a) && knownState(b)", result.getValue().toString(), + "Equality should be dropped and knownState(b) derived from knownState(a)"); + } + + @Test + void testDerivationMatchesRightOperandOfEquality() { + Expression expression = parse("state1(a) == state1(b) && knownState(b)"); + List ghostStates = List.of(ghostState("knownState", "state1")); + + ValDerivationNode result = ExpressionSimplifier.simplify(expression, Map.of(), ghostStates); + + assertEquals("knownState(b) && knownState(a)", result.getValue().toString(), + "A known state on the right operand should derive the state for the left operand"); + } + + @Test + void testChainDerivationResolvesTransitiveEqualities() { + Expression expression = parse("state1(a) == state1(b) && state1(b) == state1(c) && knownState(a)"); + List ghostStates = List.of(ghostState("knownState", "state1")); + + ValDerivationNode result = ExpressionSimplifier.simplify(expression, Map.of(), ghostStates); + + String actual = result.getValue().toString(); + assertTrue(actual.contains("knownState(b)"), "chain should derive knownState(b), got: " + actual); + assertTrue(actual.contains("knownState(c)"), "chain should derive knownState(c), got: " + actual); + } + + @Test + void testNoMatchingKnownStateKeepsEquality() { + Expression expression = parse("state1(a) == state1(b) && knownState(c)"); + List ghostStates = List.of(ghostState("knownState", "state1")); + + ValDerivationNode result = ExpressionSimplifier.simplify(expression, Map.of(), ghostStates); + + assertEquals("state1(a) == state1(b) && knownState(c)", result.getValue().toString(), + "With no known state for either operand the equality is kept and nothing is derived"); + } + + @Test + void testDerivationFiresOnResolvedTernaryBranchNotRawIte() { + // The state conjunct is a conditional; mode == 2 makes the (mode == 1 ? ...) condition false, so the + // asserted state is the else-branch otherState(a). Derivation must fire on that resolved fact, never + // on the raw Ite (which would wrongly derive knownState from the unreachable then-branch). + Expression expression = parse( + "mode == 2 && (mode == 1 ? knownState(a) : otherState(a)) && state1(a) == state1(b)"); + List ghostStates = List.of(ghostState("knownState", "state1"), ghostState("otherState", "state1")); + + ValDerivationNode result = ExpressionSimplifier.simplify(expression, Map.of(), ghostStates); + + String actual = result.getValue().toString(); + assertEquals("otherState(a) && otherState(b)", actual, "Derivation should fire on the resolved else-branch"); + assertFalse(actual.contains("knownState"), "Must not derive from the unreachable then-branch: " + actual); + } + + @Test + void testDerivedStateAlreadyPresentIsNotDuplicated() { + Expression expression = parse("state1(a) == state1(b) && knownState(a) && knownState(b)"); + List ghostStates = List.of(ghostState("knownState", "state1")); + + ValDerivationNode result = ExpressionSimplifier.simplify(expression, Map.of(), ghostStates); + + assertEquals("knownState(a) && knownState(b)", result.getValue().toString(), + "A derived state already present as a conjunct must not be duplicated"); + } + + @Test + void testTwoIndependentGhostDimensionsDeriveSeparately() { + Expression expression = parse( + "state1(a) == state1(b) && state2(a) == state2(b) && dim1State(a) && dim2State(a)"); + List ghostStates = List.of(ghostState("dim1State", "state1"), ghostState("dim2State", "state2")); + + ValDerivationNode result = ExpressionSimplifier.simplify(expression, Map.of(), ghostStates); + + assertEquals("dim1State(a) && dim2State(a) && dim1State(b) && dim2State(b)", result.getValue().toString(), + "Each ghost dimension should derive from its own equality only"); + } + + @Test + void testDerivedStateCarriesStateDerivationNodeProvenance() { + Expression expression = parse("state1(a) == state1(b) && knownState(a)"); + List ghostStates = List.of(ghostState("knownState", "state1")); + + ValDerivationNode result = ExpressionSimplifier.simplify(expression, Map.of(), ghostStates); + + assertEquals("knownState(a) && knownState(b)", result.getValue().toString()); + assertInstanceOf(BinaryDerivationNode.class, result.getOrigin()); + BinaryDerivationNode conjunction = (BinaryDerivationNode) result.getOrigin(); + + // The conjunct that was not derived keeps its original (here origin-less) node. + ValDerivationNode kept = conjunction.getLeft(); + assertEquals("knownState(a)", kept.getValue().toString()); + assertNull(kept.getOrigin(), "A conjunct that was not derived should keep its original node"); + + // The derived conjunct carries a StateDerivationNode recording how it was derived. + ValDerivationNode derived = conjunction.getRight(); + assertEquals("knownState(b)", derived.getValue().toString()); + assertInstanceOf(StateDerivationNode.class, derived.getOrigin(), + "A derived state should carry a StateDerivationNode origin"); + StateDerivationNode provenance = (StateDerivationNode) derived.getOrigin(); + assertEquals("state1(a) == state1(b)", provenance.getSourceEquality().getValue().toString(), + "Provenance should reference the equality the state was derived through"); + assertEquals("knownState(a)", provenance.getSourceState().getValue().toString(), + "Provenance should reference the known state the derivation started from"); + } + + @Test + void testDerivationDoesNotFireFromDisjunctionBranch() { + // A state inside a disjunction is not an asserted fact: the object is in knownState OR otherState, + // not definitely either. Derivation must not carry such a state across the equality. + Expression expression = parse("state1(a) == state1(b) && (knownState(a) || otherState(a))"); + List ghostStates = List.of(ghostState("knownState", "state1"), ghostState("otherState", "state1")); + + ValDerivationNode result = ExpressionSimplifier.simplify(expression, Map.of(), ghostStates); + + String actual = result.getValue().toString(); + assertFalse(actual.contains("knownState(b)"), "Must not derive a state from a disjunction branch: " + actual); + assertFalse(actual.contains("otherState(b)"), "Must not derive a state from a disjunction branch: " + actual); + assertTrue(actual.contains("state1(a) == state1(b)"), + "The equality must be kept when nothing can be soundly derived: " + actual); + } + + @Test + void testDerivationDoesNotFireFromNegatedState() { + // !knownState(a) asserts the object is NOT in knownState; it is not a known state to carry across + // the equality. Derivation must leave the equality intact. + Expression expression = parse("state1(a) == state1(b) && !knownState(a)"); + List ghostStates = List.of(ghostState("knownState", "state1")); + + ValDerivationNode result = ExpressionSimplifier.simplify(expression, Map.of(), ghostStates); + + String actual = result.getValue().toString(); + assertFalse(actual.contains("knownState(b)"), "Must not derive a state from a negated state: " + actual); + assertTrue(actual.contains("state1(a) == state1(b)"), + "The equality must be kept when nothing can be soundly derived: " + actual); + } +} diff --git a/liquidjava-verifier/src/test/java/liquidjava/utils/TestUtils.java b/liquidjava-verifier/src/test/java/liquidjava/utils/TestUtils.java index 35e1ead6..ae4a93ae 100644 --- a/liquidjava-verifier/src/test/java/liquidjava/utils/TestUtils.java +++ b/liquidjava-verifier/src/test/java/liquidjava/utils/TestUtils.java @@ -17,6 +17,7 @@ import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode; import liquidjava.rj_language.opt.derivation_node.DerivationNode; import liquidjava.rj_language.opt.derivation_node.IteDerivationNode; +import liquidjava.rj_language.opt.derivation_node.StateDerivationNode; import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode; import liquidjava.rj_language.opt.derivation_node.ValDerivationNode; import liquidjava.rj_language.opt.derivation_node.VarDerivationNode; @@ -126,6 +127,12 @@ public static void assertDerivationEquals(DerivationNode expected, DerivationNod assertDerivationEquals(expectedIte.getCondition(), actualIte.getCondition(), message + " > condition"); assertDerivationEquals(expectedIte.getThenBranch(), actualIte.getThenBranch(), message + " > then"); assertDerivationEquals(expectedIte.getElseBranch(), actualIte.getElseBranch(), message + " > else"); + } else if (expected instanceof StateDerivationNode expectedState) { + StateDerivationNode actualState = (StateDerivationNode) actual; + assertDerivationEquals(expectedState.getSourceEquality(), actualState.getSourceEquality(), + message + " > sourceEquality"); + assertDerivationEquals(expectedState.getSourceState(), actualState.getSourceState(), + message + " > sourceState"); } }