Implement transfer functions for nullptr, the address-of (&) operator, and comparisons between pointers.

PiperOrigin-RevId: 453859336
diff --git a/nullability_verification/pointer_nullability_analysis.cc b/nullability_verification/pointer_nullability_analysis.cc
index 5cde0c3..287365a 100644
--- a/nullability_verification/pointer_nullability_analysis.cc
+++ b/nullability_verification/pointer_nullability_analysis.cc
@@ -4,7 +4,6 @@
 
 #include "nullability_verification/pointer_nullability_analysis.h"
 
-#include <iostream>
 #include <string>
 
 #include "common/check.h"
@@ -12,8 +11,10 @@
 #include "nullability_verification/pointer_nullability_matchers.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Expr.h"
+#include "clang/AST/OperationKinds.h"
+#include "clang/AST/Stmt.h"
+#include "clang/AST/Type.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
-#include "clang/ASTMatchers/ASTMatchers.h"
 #include "clang/Analysis/FlowSensitive/DataflowEnvironment.h"
 #include "clang/Analysis/FlowSensitive/MatchSwitch.h"
 #include "clang/Analysis/FlowSensitive/Value.h"
@@ -34,6 +35,14 @@
 
 namespace {
 
+BoolValue& getPointerNullability(  // Returns the "is non-null" BoolValue.
+    const Expr* PointerExpr, TransferState<PointerNullabilityLattice>& State) {
+  auto* PointerVal =  // cast<>: expression must have a PointerValue.
+      cast<PointerValue>(State.Env.getValue(*PointerExpr, SkipPast::Reference));
+  CHECK(State.Lattice.hasPointerNullability(PointerVal));  // Must be recorded.
+  return *State.Lattice.getPointerNullability(PointerVal);
+}
+
 void initialisePointerNullability(
     const Expr* Expr, const MatchFinder::MatchResult&,
     TransferState<PointerNullabilityLattice>& State) {
@@ -46,37 +55,93 @@
   }
 }

+// Models a null pointer literal (e.g. `nullptr`): ensures the expression has a
+// PointerValue and, unless one is already recorded, sets its nullability to
+// the literal `false` (known null).
+void transferNullPointerLiteral(
+    const Expr* NullPointer, const MatchFinder::MatchResult&,
+    TransferState<PointerNullabilityLattice>& State) {
+  auto* NullPointerVal = cast_or_null<PointerValue>(
+      State.Env.getValue(*NullPointer, SkipPast::None));
+  if (NullPointerVal == nullptr) {
+    // The literal has no value yet: create a storage location and a
+    // PointerValue for it.
+    auto& NullPointerLoc = State.Env.createStorageLocation(*NullPointer);
+    NullPointerVal = &State.Env.takeOwnership(
+        std::make_unique<PointerValue>(NullPointerLoc));
+    State.Env.setStorageLocation(*NullPointer, NullPointerLoc);
+    State.Env.setValue(NullPointerLoc, *NullPointerVal);
+  }
+  if (!State.Lattice.hasPointerNullability(NullPointerVal)) {
+    // Record the literal as known null if nothing is recorded yet.
+    State.Lattice.setPointerNullability(NullPointerVal,
+                                        &State.Env.getBoolLiteralValue(false));
+  }
+}
+
+// Models the address-of operator (`&x`), whose result is never null, by
+// setting its nullability to the literal `true`. If the framework produced no
+// value for the expression there is nothing to annotate, so it is skipped.
+void transferAddrOf(const UnaryOperator* UnaryOp,
+                    const MatchFinder::MatchResult&,
+                    TransferState<PointerNullabilityLattice>& State) {
+  if (auto* PointerVal = cast_or_null<PointerValue>(
+          State.Env.getValue(*UnaryOp, SkipPast::None))) {
+    State.Lattice.setPointerNullability(PointerVal,
+                                        &State.Env.getBoolLiteralValue(true));
+  }
+}
+
+// Reports a violation when a pointer is dereferenced and the flow condition
+// does not prove it non-null.
 void transferDereference(const UnaryOperator* UnaryOp,
                          const MatchFinder::MatchResult&,
                          TransferState<PointerNullabilityLattice>& State) {
   auto* PointerExpr = UnaryOp->getSubExpr();
-  if (auto* PointerVal = cast_or_null<PointerValue>(
-          State.Env.getValue(*PointerExpr, SkipPast::Reference))) {
-    auto PointerNullability = State.Lattice.getPointerNullability(PointerVal);
-    CHECK(PointerNullability != nullptr);
-    if (State.Env.flowConditionImplies(*PointerNullability)) {
-      return;
-    }
+  // NOTE: getPointerNullability CHECK-fails if the operand has no modeled
+  // pointer value or no recorded nullability.
+  auto& PointerNullability = getPointerNullability(PointerExpr, State);
+  if (!State.Env.flowConditionImplies(PointerNullability)) {
+    State.Lattice.addViolation(PointerExpr);
   }
-  State.Lattice.addViolation(PointerExpr);
 }

+// Refines the flow condition for `==` / `!=` between two pointers using their
+// nullability: two null pointers compare equal, and a null pointer never
+// equals a non-null one. Two non-null pointers are left unconstrained.
 void transferNullCheckComparison(
-    const Expr* NullCheck, const Expr* PointerExpr,
+    const BinaryOperator* BinaryOp, const MatchFinder::MatchResult&,
     TransferState<PointerNullabilityLattice>& State) {
-  if (auto* PointerVal = cast_or_null<PointerValue>(
-          State.Env.getValue(*PointerExpr, SkipPast::Reference))) {
-    auto* PointerNullability = State.Lattice.getPointerNullability(PointerVal);
-    CHECK(PointerNullability != nullptr);
+  // Boolean for the result of the comparison; the dataflow framework creates
+  // it automatically for the BinaryOperator.
+  auto& PointerComparison =
+      *cast<BoolValue>(State.Env.getValue(*BinaryOp, SkipPast::None));

-    // For binary operations, the dataflow framework automatically creates a
-    // corresponding BoolVal
-    auto* ExistingDFVal =
-        cast_or_null<BoolValue>(State.Env.getValue(*NullCheck, SkipPast::None));
-    CHECK(ExistingDFVal != nullptr);
-    State.Env.addToFlowCondition(
-        State.Env.makeIff(*ExistingDFVal, *PointerNullability));
-  }
+  CHECK(BinaryOp->getOpcode() == BO_EQ || BinaryOp->getOpcode() == BO_NE);
+  auto& PointerEQ = BinaryOp->getOpcode() == BO_EQ
+                        ? PointerComparison
+                        : State.Env.makeNot(PointerComparison);
+  auto& PointerNE = BinaryOp->getOpcode() == BO_EQ
+                        ? State.Env.makeNot(PointerComparison)
+                        : PointerComparison;
+
+  auto& LHSNullability = getPointerNullability(BinaryOp->getLHS(), State);
+  auto& RHSNullability = getPointerNullability(BinaryOp->getRHS(), State);
+
+  // LHS null && RHS null => LHS == RHS
+  State.Env.addToFlowCondition(State.Env.makeImplication(
+      State.Env.makeAnd(State.Env.makeNot(LHSNullability),
+                        State.Env.makeNot(RHSNullability)),
+      PointerEQ));
+  // LHS null && RHS non-null => LHS != RHS
+  State.Env.addToFlowCondition(State.Env.makeImplication(
+      State.Env.makeAnd(State.Env.makeNot(LHSNullability), RHSNullability),
+      PointerNE));
+  // LHS non-null && RHS null => LHS != RHS
+  State.Env.addToFlowCondition(State.Env.makeImplication(
+      State.Env.makeAnd(LHSNullability, State.Env.makeNot(RHSNullability)),
+      PointerNE));
+  // Both non-null: they may or may not be equal, so nothing is added.
 }

 void transferNullCheckImplicitCastPtrToBool(
@@ -95,20 +144,19 @@
 
 auto buildTransferer() {
   return MatchSwitchBuilder<TransferState<PointerNullabilityLattice>>()
-      // Initialise nullability state of pointers
-      .CaseOf<Expr>(isPointerExpr(), initialisePointerNullability)
-      // Pointer dereference
+      // References to pointer variables: initialise their nullability state
+      .CaseOf<Expr>(isPointerVariableReference(), initialisePointerNullability)
+      // Null pointer literals (e.g. nullptr): known null
+      .CaseOf<Expr>(isNullPointerLiteral(), transferNullPointerLiteral)
+      // Address-of operator (&var): known non-null
+      .CaseOf<UnaryOperator>(isAddrOf(), transferAddrOf)
+      // Pointer dereference (*ptr): flagged unless provably non-null
       .CaseOf<UnaryOperator>(isPointerDereference(), transferDereference)
-      // Nullability check
-      .CaseOf<BinaryOperator>(
-          isNEQNullBinOp(/*BindID=*/"pointer"),
-          [](const BinaryOperator* binOp,
-             const MatchFinder::MatchResult& result,
-             TransferState<PointerNullabilityLattice>& State) {
-            transferNullCheckComparison(
-                binOp, result.Nodes.getNodeAs<Expr>("pointer"), State);
-          })
-      .CaseOf<Expr>(isImplicitCastPtrToBool(),
+      // Pointer comparisons (==, !=): constrain the flow condition
+      .CaseOf<BinaryOperator>(isPointerCheckBinOp(),
+                              transferNullCheckComparison)
+      // Pointer used as a boolean (e.g. `if (ptr)`)
+      .CaseOf<Expr>(isImplicitCastPointerToBool(),
                      transferNullCheckImplicitCastPtrToBool)
       .Build();
 }