diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CcLibraryHelper.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CcLibraryHelper.java
index e1352d5..dbed5de 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CcLibraryHelper.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CcLibraryHelper.java
@@ -1174,9 +1174,8 @@
               ? CppHelper.createDefaultCppModuleMap(ruleContext)
               : injectedCppModuleMap;
       contextBuilder.setCppModuleMap(cppModuleMap);
-      boolean useModules = featureConfiguration.isEnabled(CppRuleClasses.USE_HEADER_MODULES);
-      contextBuilder.setUseHeaderModules(useModules);
-      if (useModules && featureConfiguration.isEnabled(CppRuleClasses.TRANSITIVE_MODULE_MAPS)) {
+      if (featureConfiguration.isEnabled(CppRuleClasses.USE_HEADER_MODULES)
+          && featureConfiguration.isEnabled(CppRuleClasses.TRANSITIVE_MODULE_MAPS)) {
         contextBuilder.setProvideTransitiveModuleMaps(true);
       }
       if (createModuleMapActions) {
diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompilationContext.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompilationContext.java
index ac8f38b..b1cd506 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompilationContext.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompilationContext.java
@@ -70,9 +70,6 @@
   // True if this context is for a compilation that needs transitive module maps.
   private final boolean provideTransitiveModuleMaps;
   
-  // True if this context is for a compilation that should use header modules from dependencies.
-  private final boolean useHeaderModules;
-
   // Derived from depsContexts.
   private final ImmutableSet<Artifact> compilationPrerequisites;
 
@@ -88,8 +85,7 @@
       NestedSet<Artifact> transitiveModuleMaps,
       NestedSet<Artifact> directModuleMaps,
       CppModuleMap cppModuleMap,
-      boolean provideTransitiveModuleMaps,
-      boolean useHeaderModules) {
+      boolean provideTransitiveModuleMaps) {
     Preconditions.checkNotNull(commandLineContext);
     this.commandLineContext = commandLineContext;
     this.declaredIncludeDirs = declaredIncludeDirs;
@@ -102,7 +98,6 @@
     this.transitiveModuleMaps = transitiveModuleMaps;
     this.cppModuleMap = cppModuleMap;
     this.provideTransitiveModuleMaps = provideTransitiveModuleMaps;
-    this.useHeaderModules = useHeaderModules;
     this.compilationPrerequisites = compilationPrerequisites;
   }
 
@@ -292,8 +287,7 @@
         context.transitiveModuleMaps,
         context.directModuleMaps,
         context.cppModuleMap,
-        context.provideTransitiveModuleMaps,
-        context.useHeaderModules);
+        context.provideTransitiveModuleMaps);
   }
 
   /**
@@ -344,8 +338,7 @@
         mergeSets(ownerContext.transitiveModuleMaps, libContext.transitiveModuleMaps),
         mergeSets(ownerContext.directModuleMaps, libContext.directModuleMaps),
         libContext.cppModuleMap,
-        libContext.provideTransitiveModuleMaps,
-        libContext.useHeaderModules);
+        libContext.provideTransitiveModuleMaps);
   }
   
   /**
@@ -363,11 +356,6 @@
     return cppModuleMap;
   }
 
-  /** @return whether header modules should be used in this context. */
-  public boolean getUseHeaderModules() {
-    return useHeaderModules;
-  }
-
   /**
    * The parts of the compilation context that influence the command line of
    * compilation actions.
@@ -414,7 +402,6 @@
     private final Set<String> defines = new LinkedHashSet<>();
     private CppModuleMap cppModuleMap;
     private boolean provideTransitiveModuleMaps = false;
-    private boolean useHeaderModules = false;
 
     /** The rule that owns the context */
     private final RuleContext ruleContext;
@@ -681,15 +668,6 @@
     }
     
     /**
-     * Sets that the context will be used by a compilation that uses header modules provided by
-     * its dependencies.
-     */
-    public Builder setUseHeaderModules(boolean useHeaderModules) {
-      this.useHeaderModules = useHeaderModules;
-      return this;
-    }
-
-    /**
      * Builds the {@link CppCompilationContext}.
      */
     public CppCompilationContext build() {
@@ -725,8 +703,7 @@
           transitiveModuleMaps.build(),
           directModuleMaps.build(),
           cppModuleMap,
-          provideTransitiveModuleMaps,
-          useHeaderModules);
+          provideTransitiveModuleMaps);
     }
 
     /**
diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileAction.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileAction.java
index 1ef044f..8a03ceb 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileAction.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileAction.java
@@ -172,6 +172,7 @@
   private final boolean shouldScanIncludes;
   private final boolean shouldPruneModules;
   private final boolean usePic;
+  private final boolean useHeaderModules;
   private final CppCompilationContext context;
   private final Iterable<IncludeScannable> lipoScannables;
   private final ImmutableList<Artifact> builtinIncludeFiles;
@@ -260,6 +261,7 @@
       boolean shouldScanIncludes,
       boolean shouldPruneModules,
       boolean usePic,
+      boolean useHeaderModules,
       Label sourceLabel,
       NestedSet<Artifact> mandatoryInputs,
       Artifact outputFile,
@@ -289,7 +291,7 @@
             ruleContext,
             mandatoryInputs,
             context.getTransitiveCompilationPrerequisites(),
-            context.getUseHeaderModules() && !cppConfiguration.getSkipUnusedModules()
+            useHeaderModules && !cppConfiguration.getSkipUnusedModules()
                 ? context.getTransitiveModules(usePic)
                 : null,
             optionalSourceFile,
@@ -312,6 +314,7 @@
     this.shouldScanIncludes = shouldScanIncludes;
     this.shouldPruneModules = shouldPruneModules;
     this.usePic = usePic;
+    this.useHeaderModules = useHeaderModules;
     this.inputsKnown = !shouldScanIncludes;
     this.cppCompileCommandLine =
         new CppCompileCommandLine(
@@ -533,8 +536,7 @@
 
   @Override
   public Iterable<Artifact> getInputsWhenSkippingInputDiscovery() {
-    if (context.getUseHeaderModules()
-        && cppConfiguration.getSkipUnusedModules()) {
+    if (useHeaderModules && cppConfiguration.getSkipUnusedModules()) {
       return context.getTransitiveModules(usePic);
     }
     return null;
diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileActionBuilder.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileActionBuilder.java
index 7a1e60b..b3af498 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileActionBuilder.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileActionBuilder.java
@@ -71,6 +71,7 @@
   private AnalysisEnvironment analysisEnvironment;
   private ImmutableList<PathFragment> extraSystemIncludePrefixes = ImmutableList.of();
   private boolean usePic;
+  private boolean allowUsingHeaderModules;
   private SpecialInputsHandler specialInputsHandler = CppCompileAction.VOID_SPECIAL_INPUTS_HANDLER;
   private UUID actionClassId = GUID;
   private Class<? extends CppCompileActionContext> actionContext;
@@ -99,6 +100,7 @@
     this.mandatoryInputsBuilder = NestedSetBuilder.stableOrder();
     this.lipoScannableMap = getLipoScannableMap(ruleContext);
     this.ruleContext = ruleContext;
+    this.allowUsingHeaderModules = true;
 
     features.addAll(ruleContext.getFeatures());
   }
@@ -146,6 +148,7 @@
     this.actionContext = other.actionContext;
     this.cppConfiguration = other.cppConfiguration;
     this.usePic = other.usePic;
+    this.allowUsingHeaderModules = other.allowUsingHeaderModules;
     this.lipoScannableMap = other.lipoScannableMap;
     this.ruleContext = other.ruleContext;
     this.shouldScanIncludes = other.shouldScanIncludes;
@@ -248,6 +251,9 @@
     // This must be set either to false or true by CppSemantics, otherwise someone forgot to call
     // finalizeCompileActionBuilder on this builder.
     Preconditions.checkNotNull(shouldScanIncludes);
+    boolean useHeaderModules =
+        allowUsingHeaderModules
+            && featureConfiguration.isEnabled(CppRuleClasses.USE_HEADER_MODULES);
 
     boolean fake = tempOutputFile != null;
 
@@ -268,11 +274,11 @@
     // before discovering inputs and thus would not declare their inputs properly.
     boolean shouldPruneModules =
         shouldScanIncludes
-            && context.getUseHeaderModules()
+            && useHeaderModules
             && !fake
             && !getActionName().equals(CppCompileAction.CPP_MODULE_COMPILE)
             && featureConfiguration.isEnabled(CppRuleClasses.PRUNE_HEADER_MODULES);
-    if (context.getUseHeaderModules() && !shouldPruneModules) {
+    if (useHeaderModules && !shouldPruneModules) {
       realMandatoryInputsBuilder.addTransitive(context.getTransitiveModules(usePic));
     }
     realMandatoryInputsBuilder.addTransitive(context.getAdditionalInputs());
@@ -299,6 +305,7 @@
           shouldScanIncludes,
           shouldPruneModules,
           usePic,
+          useHeaderModules,
           sourceLabel,
           realMandatoryInputsBuilder.build(),
           outputFile,
@@ -324,6 +331,7 @@
           shouldScanIncludes,
           shouldPruneModules,
           usePic,
+          useHeaderModules,
           sourceLabel,
           realMandatoryInputs,
           outputFile,
@@ -504,14 +512,18 @@
     return this;
   }
 
-  /**
-   * Sets whether the CompileAction should use pic mode.
-   */
+  /** Sets whether the CompileAction should use pic mode. */
   public CppCompileActionBuilder setPicMode(boolean usePic) {
     this.usePic = usePic;
     return this;
   }
 
+  /** Sets whether the CompileAction should use header modules. */
+  public CppCompileActionBuilder setAllowUsingHeaderModules(boolean allowUsingHeaderModules) {
+    this.allowUsingHeaderModules = allowUsingHeaderModules;
+    return this;
+  }
+
   /** Sets the CppSemantics for this compile. */
   public CppCompileActionBuilder setSemantics(CppSemantics semantics) {
     this.cppSemantics = semantics;
diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/FakeCppCompileAction.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/FakeCppCompileAction.java
index 4442ac2..97732c4 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/cpp/FakeCppCompileAction.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/FakeCppCompileAction.java
@@ -69,6 +69,7 @@
       boolean shouldScanIncludes,
       boolean shouldPruneModules,
       boolean usePic,
+      boolean useHeaderModules,
       Label sourceLabel,
       NestedSet<Artifact> mandatoryInputs,
       Artifact outputFile,
@@ -91,6 +92,7 @@
         shouldScanIncludes,
         shouldPruneModules,
         usePic,
+        useHeaderModules,
         sourceLabel,
         mandatoryInputs,
         outputFile,
