From f5f67458d87d53c0219e913c9f179a3897c3a306 Mon Sep 17 00:00:00 2001 From: Ben Apprederisse Date: Wed, 10 Jun 2026 14:25:40 -0700 Subject: [PATCH] Treat findings memory as per-scan scratch Amp-Thread-ID: https://ampcode.com/threads/T-019eb352-0ee8-70dd-a627-ae4034b7607c Co-authored-by: Amp --- cmd/server/main.go | 3 ++- pkg/scan/handler.go | 12 +++++++++ pkg/scan/handler_test.go | 13 +++++++++ pkg/scan/scan.go | 19 +++++++------ pkg/scan/scan_test.go | 13 ++++----- pkg/schedule/schedule.go | 16 +++++++++-- pkg/schedule/schedule_test.go | 8 ++++++ pkg/store/memory/store.go | 9 +++++++ pkg/store/memory/store_test.go | 16 +++++++++++ pkg/store/store.go | 3 +++ pkg/workflow/orchestrator/activities.go | 12 +++++++++ pkg/workflow/orchestrator/activities_test.go | 7 +++++ .../orchestrator/orchestrator_test.go | 6 +++++ pkg/workflow/orchestrator/workflow.go | 27 +++++++++++++++++++ 14 files changed, 147 insertions(+), 17 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index b44a0e5..a14ff69 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -451,6 +451,7 @@ func (s *ServerCLI) Run(_ *kong.Context) error { // Orchestrator workflow activities orchestratorActivities := orchestrator.NewActivities(st, snapshotStore) + w.RegisterActivityWithOptions(orchestratorActivities.ClearFindings, activity.RegisterOptions{Name: orchestrator.ClearFindingsActivityName}) w.RegisterActivityWithOptions(orchestratorActivities.RecordResourceScanResult, activity.RegisterOptions{Name: orchestrator.RecordResourceScanResultActivityName}) if snapshotStore != nil { w.RegisterActivityWithOptions(orchestratorActivities.CreateSnapshot, activity.RegisterOptions{Name: orchestrator.CreateSnapshotActivityName}) @@ -479,7 +480,7 @@ func (s *ServerCLI) Run(_ *kong.Context) error { fmt.Printf(" Scans will run automatically (schedule: %s)\n", s.ScheduleCron) } fmt.Println("\nšŸ“– To trigger a scan manually, use the Temporal UI or CLI:") - fmt.Printf(" temporal workflow start --task-queue %s --type %s --input '{}'\n", s.TemporalTaskQueue, orchestrator.OrchestratorWorkflowType) + fmt.Printf(" temporal workflow start --workflow-id %s --task-queue %s --type %s --input '{}'\n", orchestrator.ActiveScanWorkflowID, s.TemporalTaskQueue, orchestrator.OrchestratorWorkflowType) fmt.Println("\nšŸ“– For more information, see the README.md") fmt.Println("\nPress Ctrl+C to stop...") diff --git a/pkg/scan/handler.go b/pkg/scan/handler.go index 9504d99..76ddd69 100644 --- a/pkg/scan/handler.go +++ b/pkg/scan/handler.go @@ -2,9 +2,12 @@ package scan import ( "encoding/json" + "errors" "fmt" "net/http" + "go.temporal.io/api/serviceerror" + "github.com/block/Version-Guard/pkg/telemetry" "github.com/block/Version-Guard/pkg/types" ) @@ -65,6 +68,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { Source: telemetry.ScanSourceHTTP, }) if err != nil { + if isScanAlreadyRunning(err) { + writeError(w, http.StatusConflict, err.Error()) + return + } writeError(w, http.StatusInternalServerError, err.Error()) return } @@ -83,3 +90,8 @@ func writeJSON(w http.ResponseWriter, status int, body interface{}) { func writeError(w http.ResponseWriter, status int, msg string) { writeJSON(w, status, map[string]string{"error": msg}) } + +func isScanAlreadyRunning(err error) bool { + var alreadyStarted *serviceerror.WorkflowExecutionAlreadyStarted + return errors.As(err, &alreadyStarted) +} diff --git a/pkg/scan/handler_test.go b/pkg/scan/handler_test.go index 83c4d17..2c1ffb1 100644 --- a/pkg/scan/handler_test.go +++ b/pkg/scan/handler_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.temporal.io/api/serviceerror" "github.com/block/Version-Guard/pkg/types" "github.com/block/Version-Guard/pkg/workflow/orchestrator" @@ -122,3 +123,15 @@ func TestHandler_TriggerError_Returns500(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, rec.Code) assert.Contains(t, rec.Body.String(), "temporal unavailable") } + +func TestHandler_ScanAlreadyRunning_Returns409(t *testing.T) { + mock := &mockStarter{err: serviceerror.NewWorkflowExecutionAlreadyStarted("already running", "request", "run")} + h := newTestHandler(t, mock) + + req := httptest.NewRequest(http.MethodPost, "/scan", http.NoBody) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusConflict, rec.Code) + assert.Contains(t, rec.Body.String(), "already running") +} diff --git a/pkg/scan/scan.go b/pkg/scan/scan.go index 8376c16..32a06f8 100644 --- a/pkg/scan/scan.go +++ b/pkg/scan/scan.go @@ -11,6 +11,7 @@ import ( "time" "github.com/google/uuid" + enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/sdk/client" "github.com/block/Version-Guard/pkg/telemetry" @@ -147,9 +148,11 @@ func (t *Trigger) Run(ctx context.Context, in Input) (res Result, err error) { workflowID = buildWorkflowID(scanID) opts := client.StartWorkflowOptions{ - ID: workflowID, - TaskQueue: t.taskQueue, - WorkflowExecutionTimeout: defaultExecutionTimeout, + ID: workflowID, + TaskQueue: t.taskQueue, + WorkflowExecutionTimeout: defaultExecutionTimeout, + WorkflowIDConflictPolicy: enumspb.WORKFLOW_ID_CONFLICT_POLICY_FAIL, + WorkflowExecutionErrorWhenAlreadyStarted: true, } run, err := t.starter.ExecuteWorkflow(ctx, opts, orchestrator.OrchestratorWorkflow, orchestrator.WorkflowInput{ @@ -171,9 +174,9 @@ func (t *Trigger) Run(ctx context.Context, in Input) (res Result, err error) { }, nil } -// buildWorkflowID produces a workflow ID that is distinguishable from -// scheduled executions. Scheduled runs use the schedule's generated IDs; -// manual runs are prefixed so they are easy to find in Temporal UI/CLI. -func buildWorkflowID(scanID string) string { - return fmt.Sprintf("version-guard-scan-%s", scanID) +// buildWorkflowID returns the singleton orchestrator workflow ID. Temporal +// rejects a new run with this ID while the previous scan is still open, which +// keeps the worker-local findings store safe to use as per-scan scratch space. +func buildWorkflowID(_ string) string { + return orchestrator.ActiveScanWorkflowID } diff --git a/pkg/scan/scan_test.go b/pkg/scan/scan_test.go index 1040827..bc2c93b 100644 --- a/pkg/scan/scan_test.go +++ b/pkg/scan/scan_test.go @@ -3,11 +3,11 @@ package scan import ( "context" "errors" - "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/sdk/client" "github.com/block/Version-Guard/pkg/types" @@ -50,7 +50,7 @@ func (m *mockStarter) ExecuteWorkflow(_ context.Context, options client.StartWor func TestTrigger_Run_FullScan(t *testing.T) { mock := &mockStarter{ - run: &mockWorkflowRun{id: "version-guard-scan-abc", runID: "run-1"}, + run: &mockWorkflowRun{id: orchestrator.ActiveScanWorkflowID, runID: "run-1"}, } defaults := []types.ResourceType{"aurora-mysql", "eks"} tr := NewTriggerWithStarter(mock, "version-guard-orchestrator", defaults) @@ -58,13 +58,15 @@ func TestTrigger_Run_FullScan(t *testing.T) { res, err := tr.Run(context.Background(), Input{ScanID: "abc"}) require.NoError(t, err) - assert.Equal(t, "version-guard-scan-abc", res.WorkflowID) + assert.Equal(t, orchestrator.ActiveScanWorkflowID, res.WorkflowID) assert.Equal(t, "run-1", res.RunID) assert.Equal(t, "abc", res.ScanID) require.True(t, mock.called) - assert.Equal(t, "version-guard-scan-abc", mock.calledOpts.ID) + assert.Equal(t, orchestrator.ActiveScanWorkflowID, mock.calledOpts.ID) assert.Equal(t, "version-guard-orchestrator", mock.calledOpts.TaskQueue) + assert.Equal(t, enumspb.WORKFLOW_ID_CONFLICT_POLICY_FAIL, mock.calledOpts.WorkflowIDConflictPolicy) + assert.True(t, mock.calledOpts.WorkflowExecutionErrorWhenAlreadyStarted) require.Len(t, mock.calledArgs, 1) in, ok := mock.calledArgs[0].(orchestrator.WorkflowInput) @@ -115,8 +117,7 @@ func TestTrigger_Run_GeneratesScanIDWhenEmpty(t *testing.T) { require.NoError(t, err) assert.NotEmpty(t, res.ScanID, "ScanID should be generated when not provided") - assert.True(t, strings.HasPrefix(mock.calledOpts.ID, "version-guard-scan-"), - "workflow ID should be prefixed") + assert.Equal(t, orchestrator.ActiveScanWorkflowID, mock.calledOpts.ID) in := mock.calledArgs[0].(orchestrator.WorkflowInput) assert.Equal(t, res.ScanID, in.ScanID, "generated ScanID should be passed to workflow") } diff --git a/pkg/schedule/schedule.go b/pkg/schedule/schedule.go index 4f6f4c6..ead7e34 100644 --- a/pkg/schedule/schedule.go +++ b/pkg/schedule/schedule.go @@ -6,6 +6,7 @@ import ( "fmt" "time" + enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/sdk/client" "go.temporal.io/sdk/temporal" @@ -77,6 +78,7 @@ func (m *Manager) EnsureSchedule(ctx context.Context, cfg Config) error { Jitter: cfg.Jitter, }, Action: &client.ScheduleWorkflowAction{ + ID: orchestrator.ActiveScanWorkflowID, Workflow: orchestrator.OrchestratorWorkflow, Args: []interface{}{orchestrator.WorkflowInput{ ResourceTypes: cfg.ResourceTypes, @@ -86,7 +88,8 @@ func (m *Manager) EnsureSchedule(ctx context.Context, cfg Config) error { TaskQueue: cfg.TaskQueue, WorkflowExecutionTimeout: 2 * time.Hour, }, - Paused: cfg.Paused, + Overlap: enumspb.SCHEDULE_OVERLAP_POLICY_SKIP, + Paused: cfg.Paused, } _, err := m.scheduleClient.Create(ctx, opts) @@ -117,8 +120,9 @@ func (m *Manager) EnsureSchedule(ctx context.Context, cfg Config) error { } existingCrons := existingSpec.CronExpressions specMatches := len(existingCrons) == 1 && existingCrons[0] == cfg.CronExpression && existingSpec.Jitter == cfg.Jitter + overlapMatches := desc.Schedule.Policy != nil && desc.Schedule.Policy.Overlap == enumspb.SCHEDULE_OVERLAP_POLICY_SKIP actionMatches := scheduleActionMatches(desc.Schedule.Action, &cfg) - if specMatches && actionMatches { + if specMatches && overlapMatches && actionMatches { fmt.Printf(" Schedule %q already configured (cron: %s)\n", cfg.ScheduleID, cfg.CronExpression) return nil } @@ -140,7 +144,12 @@ func (m *Manager) EnsureSchedule(ctx context.Context, cfg Config) error { CronExpressions: []string{cfg.CronExpression}, Jitter: cfg.Jitter, } + if input.Description.Schedule.Policy == nil { + input.Description.Schedule.Policy = &client.SchedulePolicies{} + } + input.Description.Schedule.Policy.Overlap = enumspb.SCHEDULE_OVERLAP_POLICY_SKIP if action, ok := input.Description.Schedule.Action.(*client.ScheduleWorkflowAction); ok { + action.ID = orchestrator.ActiveScanWorkflowID action.TaskQueue = cfg.TaskQueue action.Args = []interface{}{orchestrator.WorkflowInput{ ResourceTypes: cfg.ResourceTypes, @@ -175,6 +184,9 @@ func scheduleActionMatches(action client.ScheduleAction, cfg *Config) bool { if !ok { return false } + if wfAction.ID != orchestrator.ActiveScanWorkflowID { + return false + } if wfAction.TaskQueue != cfg.TaskQueue { return false } diff --git a/pkg/schedule/schedule_test.go b/pkg/schedule/schedule_test.go index c20fa34..ae66069 100644 --- a/pkg/schedule/schedule_test.go +++ b/pkg/schedule/schedule_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/sdk/client" "go.temporal.io/sdk/temporal" @@ -129,7 +130,9 @@ func TestEnsureSchedule_CreatesNew(t *testing.T) { assert.Equal(t, "test-schedule", mock.createOpts.ID) assert.Equal(t, []string{"0 */6 * * *"}, mock.createOpts.Spec.CronExpressions) assert.Equal(t, 5*time.Minute, mock.createOpts.Spec.Jitter) + assert.Equal(t, enumspb.SCHEDULE_OVERLAP_POLICY_SKIP, mock.createOpts.Overlap) action := mock.createOpts.Action.(*client.ScheduleWorkflowAction) + assert.Equal(t, orchestrator.ActiveScanWorkflowID, action.ID) assert.Equal(t, "test-queue", action.TaskQueue) assert.Equal(t, 2*time.Hour, action.WorkflowExecutionTimeout) require.Len(t, action.Args, 1) @@ -143,11 +146,13 @@ func TestEnsureSchedule_AlreadyExists_SameCron(t *testing.T) { id: "test-schedule", describeOut: &client.ScheduleDescription{ Schedule: client.Schedule{ + Policy: &client.SchedulePolicies{Overlap: enumspb.SCHEDULE_OVERLAP_POLICY_SKIP}, Spec: &client.ScheduleSpec{ CronExpressions: []string{"0 */6 * * *"}, Jitter: 5 * time.Minute, }, Action: &client.ScheduleWorkflowAction{ + ID: orchestrator.ActiveScanWorkflowID, TaskQueue: "test-queue", Args: []interface{}{orchestrator.WorkflowInput{ ResourceTypes: testResourceTypes, @@ -227,8 +232,11 @@ func TestEnsureSchedule_AlreadyExists_NewWebhookURL(t *testing.T) { require.NoError(t, err) assert.True(t, handle.updateCalled, "Update must be called when EmitterWebhookURL changes") require.NotNil(t, captured) + require.NotNil(t, captured.Schedule.Policy) + assert.Equal(t, enumspb.SCHEDULE_OVERLAP_POLICY_SKIP, captured.Schedule.Policy.Overlap) action, ok := captured.Schedule.Action.(*client.ScheduleWorkflowAction) require.True(t, ok, "action should be a ScheduleWorkflowAction") + assert.Equal(t, orchestrator.ActiveScanWorkflowID, action.ID) require.Len(t, action.Args, 1) in, ok := action.Args[0].(orchestrator.WorkflowInput) require.True(t, ok) diff --git a/pkg/store/memory/store.go b/pkg/store/memory/store.go index aa1d7f7..affbc1e 100644 --- a/pkg/store/memory/store.go +++ b/pkg/store/memory/store.go @@ -25,6 +25,15 @@ func NewStore() *Store { } } +// ClearFindings removes all findings from memory. +func (s *Store) ClearFindings(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.findings = make(map[string]*types.Finding) + return nil +} + // SaveFindings saves or updates findings in memory func (s *Store) SaveFindings(ctx context.Context, findings []*types.Finding) error { s.mu.Lock() diff --git a/pkg/store/memory/store_test.go b/pkg/store/memory/store_test.go index 676fd96..a9f16c8 100644 --- a/pkg/store/memory/store_test.go +++ b/pkg/store/memory/store_test.go @@ -70,6 +70,22 @@ func TestStore_GetFinding(t *testing.T) { assert.Nil(t, notFound) } +func TestStore_ClearFindings(t *testing.T) { + ctx := context.Background() + s := NewStore() + + require.NoError(t, s.SaveFindings(ctx, []*types.Finding{ + {ResourceID: "1", Status: types.StatusRed}, + {ResourceID: "2", Status: types.StatusGreen}, + })) + + require.NoError(t, s.ClearFindings(ctx)) + + results, err := s.ListFindings(ctx, store.FindingFilters{}) + require.NoError(t, err) + assert.Empty(t, results) +} + func TestStore_ListFindings_NoFilters(t *testing.T) { ctx := context.Background() s := NewStore() diff --git a/pkg/store/store.go b/pkg/store/store.go index 51cd6e2..1d73dc6 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -9,6 +9,9 @@ import ( // Store defines the interface for persisting and retrieving findings type Store interface { + // ClearFindings removes all findings from the store. + ClearFindings(ctx context.Context) error + // SaveFindings saves or updates findings SaveFindings(ctx context.Context, findings []*types.Finding) error diff --git a/pkg/workflow/orchestrator/activities.go b/pkg/workflow/orchestrator/activities.go index c3c2e85..07a1999 100644 --- a/pkg/workflow/orchestrator/activities.go +++ b/pkg/workflow/orchestrator/activities.go @@ -16,6 +16,7 @@ import ( // Activity names const ( CreateSnapshotActivityName = "version-guard.CreateSnapshot" + ClearFindingsActivityName = "version-guard.ClearFindings" RecordResourceScanResultActivityName = "version-guard.RecordResourceScanResult" ) @@ -66,6 +67,12 @@ func NewActivities( } } +// ClearFindings removes any findings left in the worker-local scratch store. +func (a *Activities) ClearFindings(ctx context.Context) error { + activity.GetLogger(ctx).Info("Clearing in-memory findings") + return a.Store.ClearFindings(ctx) +} + // CreateSnapshot reads findings directly from the store and persists a snapshot to S3. // This avoids passing large finding payloads through Temporal activity results, // which would exceed the 4MB gRPC message limit for large inventories (12K+ resources). @@ -122,6 +129,11 @@ func (a *Activities) CreateSnapshot(ctx context.Context, input CreateSnapshotInp if err != nil { return nil, err } + if err := a.Store.ClearFindings(ctx); err != nil { + logger.Warn("Failed to clear in-memory findings after snapshot persisted", "error", err) + } else { + logger.Info("Cleared in-memory findings after snapshot persisted", "snapshotID", snap.SnapshotID) + } logger.Info("Snapshot created and persisted", "snapshotID", snap.SnapshotID, diff --git a/pkg/workflow/orchestrator/activities_test.go b/pkg/workflow/orchestrator/activities_test.go index 81ec3d6..bc267af 100644 --- a/pkg/workflow/orchestrator/activities_test.go +++ b/pkg/workflow/orchestrator/activities_test.go @@ -11,6 +11,7 @@ import ( "go.temporal.io/sdk/testsuite" "github.com/block/Version-Guard/pkg/snapshot" + "github.com/block/Version-Guard/pkg/store" "github.com/block/Version-Guard/pkg/store/memory" "github.com/block/Version-Guard/pkg/types" ) @@ -123,6 +124,9 @@ func TestActivities_CreateSnapshot_HappyPath(t *testing.T) { assert.Equal(t, "v4", fakeSnap.saved.Version) assert.Equal(t, int64(60), fakeSnap.saved.ScanDurationSec) assert.Equal(t, 3, fakeSnap.saved.Summary.TotalResources) + findings, err := st.ListFindings(context.Background(), store.FindingFilters{}) + require.NoError(t, err) + assert.Empty(t, findings, "CreateSnapshot should clear the in-memory scratch store after persisting") } func TestActivities_CreateSnapshot_PersistFailureReturnsError(t *testing.T) { @@ -144,6 +148,9 @@ func TestActivities_CreateSnapshot_PersistFailureReturnsError(t *testing.T) { }) require.Error(t, err) assert.Contains(t, err.Error(), "s3 went down") + finding, getErr := st.GetFinding(context.Background(), "r1") + require.NoError(t, getErr) + assert.NotNil(t, finding, "failed snapshots should leave findings available for activity retry") } func TestActivities_CreateSnapshot_EmptyFindings(t *testing.T) { diff --git a/pkg/workflow/orchestrator/orchestrator_test.go b/pkg/workflow/orchestrator/orchestrator_test.go index 154bc47..dab6518 100644 --- a/pkg/workflow/orchestrator/orchestrator_test.go +++ b/pkg/workflow/orchestrator/orchestrator_test.go @@ -25,6 +25,12 @@ func newOrchestratorEnv(t *testing.T) *testsuite.TestWorkflowEnvironment { suite := &testsuite.WorkflowTestSuite{} env := suite.NewTestWorkflowEnvironment() env.RegisterWorkflow(OrchestratorWorkflow) + env.RegisterActivityWithOptions( + func(_ context.Context) error { + return nil + }, + activity.RegisterOptions{Name: ClearFindingsActivityName}, + ) env.RegisterActivityWithOptions( func(_ context.Context, _ RecordResourceScanResultInput) error { return nil diff --git a/pkg/workflow/orchestrator/workflow.go b/pkg/workflow/orchestrator/workflow.go index 15197e2..6e22827 100644 --- a/pkg/workflow/orchestrator/workflow.go +++ b/pkg/workflow/orchestrator/workflow.go @@ -23,6 +23,7 @@ var ErrNoResourceTypes = fmt.Errorf("orchestrator: WorkflowInput.ResourceTypes i const ( OrchestratorWorkflowType = "VersionGuardOrchestratorWorkflow" TaskQueueName = "version-guard-orchestrator" + ActiveScanWorkflowID = "version-guard-active-scan" ScanScopeFull = "full" ScanScopeTargeted = "targeted" @@ -120,6 +121,32 @@ func OrchestratorWorkflow(ctx workflow.Context, input WorkflowInput) (*WorkflowO return nil, ErrNoResourceTypes } + // The worker-local findings store is scan scratch space only. Clear + // leftovers from any previous failed or interrupted scan before child + // detection workflows repopulate it. + if workflow.GetVersion(ctx, "clear-findings-at-scan-start", workflow.DefaultVersion, 1) == 1 { + if err := workflow.ExecuteActivity( + workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ + StartToCloseTimeout: 10 * time.Second, + RetryPolicy: &temporal.RetryPolicy{ + InitialInterval: time.Second, + BackoffCoefficient: 2.0, + MaximumInterval: 10 * time.Second, + MaximumAttempts: 3, + }, + }), + ClearFindingsActivityName, + ).Get(ctx, nil); err != nil { + logger.Error("Failed to clear in-memory findings before scan", + "event", "scan_workflow_failed", + "scanID", input.ScanID, + "workflowID", info.WorkflowExecution.ID, + "runID", info.WorkflowExecution.RunID, + "error", err) + return nil, err + } + } + // Retry policy for child workflows retryPolicy := &temporal.RetryPolicy{ InitialInterval: time.Second,