Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ func (s *ServerCLI) Run(_ *kong.Context) error {

// Orchestrator workflow activities
orchestratorActivities := orchestrator.NewActivities(st, snapshotStore)
w.RegisterActivityWithOptions(orchestratorActivities.ClearFindings, activity.RegisterOptions{Name: orchestrator.ClearFindingsActivityName})
w.RegisterActivityWithOptions(orchestratorActivities.RecordResourceScanResult, activity.RegisterOptions{Name: orchestrator.RecordResourceScanResultActivityName})
if snapshotStore != nil {
w.RegisterActivityWithOptions(orchestratorActivities.CreateSnapshot, activity.RegisterOptions{Name: orchestrator.CreateSnapshotActivityName})
Expand Down Expand Up @@ -479,7 +480,7 @@ func (s *ServerCLI) Run(_ *kong.Context) error {
fmt.Printf(" Scans will run automatically (schedule: %s)\n", s.ScheduleCron)
}
fmt.Println("\n📖 To trigger a scan manually, use the Temporal UI or CLI:")
fmt.Printf(" temporal workflow start --task-queue %s --type %s --input '{}'\n", s.TemporalTaskQueue, orchestrator.OrchestratorWorkflowType)
fmt.Printf(" temporal workflow start --workflow-id %s --task-queue %s --type %s --input '{}'\n", orchestrator.ActiveScanWorkflowID, s.TemporalTaskQueue, orchestrator.OrchestratorWorkflowType)
fmt.Println("\n📖 For more information, see the README.md")
fmt.Println("\nPress Ctrl+C to stop...")

Expand Down
12 changes: 12 additions & 0 deletions pkg/scan/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package scan

import (
"encoding/json"
"errors"
"fmt"
"net/http"

"go.temporal.io/api/serviceerror"

"github.com/block/Version-Guard/pkg/telemetry"
"github.com/block/Version-Guard/pkg/types"
)
Expand Down Expand Up @@ -65,6 +68,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Source: telemetry.ScanSourceHTTP,
})
if err != nil {
if isScanAlreadyRunning(err) {
writeError(w, http.StatusConflict, err.Error())
return
}
writeError(w, http.StatusInternalServerError, err.Error())
return
}
Expand All @@ -83,3 +90,8 @@ func writeJSON(w http.ResponseWriter, status int, body interface{}) {
func writeError(w http.ResponseWriter, status int, msg string) {
writeJSON(w, status, map[string]string{"error": msg})
}

func isScanAlreadyRunning(err error) bool {
var alreadyStarted *serviceerror.WorkflowExecutionAlreadyStarted
return errors.As(err, &alreadyStarted)
}
13 changes: 13 additions & 0 deletions pkg/scan/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.temporal.io/api/serviceerror"

"github.com/block/Version-Guard/pkg/types"
"github.com/block/Version-Guard/pkg/workflow/orchestrator"
Expand Down Expand Up @@ -122,3 +123,15 @@ func TestHandler_TriggerError_Returns500(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "temporal unavailable")
}

func TestHandler_ScanAlreadyRunning_Returns409(t *testing.T) {
mock := &mockStarter{err: serviceerror.NewWorkflowExecutionAlreadyStarted("already running", "request", "run")}
h := newTestHandler(t, mock)

req := httptest.NewRequest(http.MethodPost, "/scan", http.NoBody)
rec := httptest.NewRecorder()
h.ServeHTTP(rec, req)

assert.Equal(t, http.StatusConflict, rec.Code)
assert.Contains(t, rec.Body.String(), "already running")
}
19 changes: 11 additions & 8 deletions pkg/scan/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/google/uuid"
enumspb "go.temporal.io/api/enums/v1"
"go.temporal.io/sdk/client"

"github.com/block/Version-Guard/pkg/telemetry"
Expand Down Expand Up @@ -147,9 +148,11 @@ func (t *Trigger) Run(ctx context.Context, in Input) (res Result, err error) {
workflowID = buildWorkflowID(scanID)

opts := client.StartWorkflowOptions{
ID: workflowID,
TaskQueue: t.taskQueue,
WorkflowExecutionTimeout: defaultExecutionTimeout,
ID: workflowID,
TaskQueue: t.taskQueue,
WorkflowExecutionTimeout: defaultExecutionTimeout,
WorkflowIDConflictPolicy: enumspb.WORKFLOW_ID_CONFLICT_POLICY_FAIL,
WorkflowExecutionErrorWhenAlreadyStarted: true,
}

run, err := t.starter.ExecuteWorkflow(ctx, opts, orchestrator.OrchestratorWorkflow, orchestrator.WorkflowInput{
Expand All @@ -171,9 +174,9 @@ func (t *Trigger) Run(ctx context.Context, in Input) (res Result, err error) {
}, nil
}

// buildWorkflowID produces a workflow ID that is distinguishable from
// scheduled executions. Scheduled runs use the schedule's generated IDs;
// manual runs are prefixed so they are easy to find in Temporal UI/CLI.
func buildWorkflowID(scanID string) string {
return fmt.Sprintf("version-guard-scan-%s", scanID)
// buildWorkflowID returns the singleton orchestrator workflow ID. Temporal
// rejects a new run with this ID while the previous scan is still open, which
// keeps the worker-local findings store safe to use as per-scan scratch space.
func buildWorkflowID(_ string) string {
return orchestrator.ActiveScanWorkflowID
}
13 changes: 7 additions & 6 deletions pkg/scan/scan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ package scan
import (
"context"
"errors"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
enumspb "go.temporal.io/api/enums/v1"
"go.temporal.io/sdk/client"

"github.com/block/Version-Guard/pkg/types"
Expand Down Expand Up @@ -50,21 +50,23 @@ func (m *mockStarter) ExecuteWorkflow(_ context.Context, options client.StartWor

func TestTrigger_Run_FullScan(t *testing.T) {
mock := &mockStarter{
run: &mockWorkflowRun{id: "version-guard-scan-abc", runID: "run-1"},
run: &mockWorkflowRun{id: orchestrator.ActiveScanWorkflowID, runID: "run-1"},
}
defaults := []types.ResourceType{"aurora-mysql", "eks"}
tr := NewTriggerWithStarter(mock, "version-guard-orchestrator", defaults)

res, err := tr.Run(context.Background(), Input{ScanID: "abc"})

require.NoError(t, err)
assert.Equal(t, "version-guard-scan-abc", res.WorkflowID)
assert.Equal(t, orchestrator.ActiveScanWorkflowID, res.WorkflowID)
assert.Equal(t, "run-1", res.RunID)
assert.Equal(t, "abc", res.ScanID)

require.True(t, mock.called)
assert.Equal(t, "version-guard-scan-abc", mock.calledOpts.ID)
assert.Equal(t, orchestrator.ActiveScanWorkflowID, mock.calledOpts.ID)
assert.Equal(t, "version-guard-orchestrator", mock.calledOpts.TaskQueue)
assert.Equal(t, enumspb.WORKFLOW_ID_CONFLICT_POLICY_FAIL, mock.calledOpts.WorkflowIDConflictPolicy)
assert.True(t, mock.calledOpts.WorkflowExecutionErrorWhenAlreadyStarted)

require.Len(t, mock.calledArgs, 1)
in, ok := mock.calledArgs[0].(orchestrator.WorkflowInput)
Expand Down Expand Up @@ -115,8 +117,7 @@ func TestTrigger_Run_GeneratesScanIDWhenEmpty(t *testing.T) {

require.NoError(t, err)
assert.NotEmpty(t, res.ScanID, "ScanID should be generated when not provided")
assert.True(t, strings.HasPrefix(mock.calledOpts.ID, "version-guard-scan-"),
"workflow ID should be prefixed")
assert.Equal(t, orchestrator.ActiveScanWorkflowID, mock.calledOpts.ID)
in := mock.calledArgs[0].(orchestrator.WorkflowInput)
assert.Equal(t, res.ScanID, in.ScanID, "generated ScanID should be passed to workflow")
}
Expand Down
16 changes: 14 additions & 2 deletions pkg/schedule/schedule.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"time"

enumspb "go.temporal.io/api/enums/v1"
"go.temporal.io/sdk/client"
"go.temporal.io/sdk/temporal"

Expand Down Expand Up @@ -77,6 +78,7 @@ func (m *Manager) EnsureSchedule(ctx context.Context, cfg Config) error {
Jitter: cfg.Jitter,
},
Action: &client.ScheduleWorkflowAction{
ID: orchestrator.ActiveScanWorkflowID,
Workflow: orchestrator.OrchestratorWorkflow,
Args: []interface{}{orchestrator.WorkflowInput{
ResourceTypes: cfg.ResourceTypes,
Expand All @@ -86,7 +88,8 @@ func (m *Manager) EnsureSchedule(ctx context.Context, cfg Config) error {
TaskQueue: cfg.TaskQueue,
WorkflowExecutionTimeout: 2 * time.Hour,
},
Paused: cfg.Paused,
Overlap: enumspb.SCHEDULE_OVERLAP_POLICY_SKIP,
Paused: cfg.Paused,
}

_, err := m.scheduleClient.Create(ctx, opts)
Expand Down Expand Up @@ -117,8 +120,9 @@ func (m *Manager) EnsureSchedule(ctx context.Context, cfg Config) error {
}
existingCrons := existingSpec.CronExpressions
specMatches := len(existingCrons) == 1 && existingCrons[0] == cfg.CronExpression && existingSpec.Jitter == cfg.Jitter
overlapMatches := desc.Schedule.Policy != nil && desc.Schedule.Policy.Overlap == enumspb.SCHEDULE_OVERLAP_POLICY_SKIP
actionMatches := scheduleActionMatches(desc.Schedule.Action, &cfg)
if specMatches && actionMatches {
if specMatches && overlapMatches && actionMatches {
fmt.Printf(" Schedule %q already configured (cron: %s)\n", cfg.ScheduleID, cfg.CronExpression)
return nil
}
Expand All @@ -140,7 +144,12 @@ func (m *Manager) EnsureSchedule(ctx context.Context, cfg Config) error {
CronExpressions: []string{cfg.CronExpression},
Jitter: cfg.Jitter,
}
if input.Description.Schedule.Policy == nil {
input.Description.Schedule.Policy = &client.SchedulePolicies{}
}
input.Description.Schedule.Policy.Overlap = enumspb.SCHEDULE_OVERLAP_POLICY_SKIP
if action, ok := input.Description.Schedule.Action.(*client.ScheduleWorkflowAction); ok {
action.ID = orchestrator.ActiveScanWorkflowID
action.TaskQueue = cfg.TaskQueue
action.Args = []interface{}{orchestrator.WorkflowInput{
ResourceTypes: cfg.ResourceTypes,
Expand Down Expand Up @@ -175,6 +184,9 @@ func scheduleActionMatches(action client.ScheduleAction, cfg *Config) bool {
if !ok {
return false
}
if wfAction.ID != orchestrator.ActiveScanWorkflowID {
return false
}
if wfAction.TaskQueue != cfg.TaskQueue {
return false
}
Expand Down
8 changes: 8 additions & 0 deletions pkg/schedule/schedule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
enumspb "go.temporal.io/api/enums/v1"
"go.temporal.io/sdk/client"
"go.temporal.io/sdk/temporal"

Expand Down Expand Up @@ -129,7 +130,9 @@ func TestEnsureSchedule_CreatesNew(t *testing.T) {
assert.Equal(t, "test-schedule", mock.createOpts.ID)
assert.Equal(t, []string{"0 */6 * * *"}, mock.createOpts.Spec.CronExpressions)
assert.Equal(t, 5*time.Minute, mock.createOpts.Spec.Jitter)
assert.Equal(t, enumspb.SCHEDULE_OVERLAP_POLICY_SKIP, mock.createOpts.Overlap)
action := mock.createOpts.Action.(*client.ScheduleWorkflowAction)
assert.Equal(t, orchestrator.ActiveScanWorkflowID, action.ID)
assert.Equal(t, "test-queue", action.TaskQueue)
assert.Equal(t, 2*time.Hour, action.WorkflowExecutionTimeout)
require.Len(t, action.Args, 1)
Expand All @@ -143,11 +146,13 @@ func TestEnsureSchedule_AlreadyExists_SameCron(t *testing.T) {
id: "test-schedule",
describeOut: &client.ScheduleDescription{
Schedule: client.Schedule{
Policy: &client.SchedulePolicies{Overlap: enumspb.SCHEDULE_OVERLAP_POLICY_SKIP},
Spec: &client.ScheduleSpec{
CronExpressions: []string{"0 */6 * * *"},
Jitter: 5 * time.Minute,
},
Action: &client.ScheduleWorkflowAction{
ID: orchestrator.ActiveScanWorkflowID,
TaskQueue: "test-queue",
Args: []interface{}{orchestrator.WorkflowInput{
ResourceTypes: testResourceTypes,
Expand Down Expand Up @@ -227,8 +232,11 @@ func TestEnsureSchedule_AlreadyExists_NewWebhookURL(t *testing.T) {
require.NoError(t, err)
assert.True(t, handle.updateCalled, "Update must be called when EmitterWebhookURL changes")
require.NotNil(t, captured)
require.NotNil(t, captured.Schedule.Policy)
assert.Equal(t, enumspb.SCHEDULE_OVERLAP_POLICY_SKIP, captured.Schedule.Policy.Overlap)
action, ok := captured.Schedule.Action.(*client.ScheduleWorkflowAction)
require.True(t, ok, "action should be a ScheduleWorkflowAction")
assert.Equal(t, orchestrator.ActiveScanWorkflowID, action.ID)
require.Len(t, action.Args, 1)
in, ok := action.Args[0].(orchestrator.WorkflowInput)
require.True(t, ok)
Expand Down
9 changes: 9 additions & 0 deletions pkg/store/memory/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ func NewStore() *Store {
}
}

// ClearFindings removes all findings from memory.
func (s *Store) ClearFindings(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()

s.findings = make(map[string]*types.Finding)
return nil
}

// SaveFindings saves or updates findings in memory
func (s *Store) SaveFindings(ctx context.Context, findings []*types.Finding) error {
s.mu.Lock()
Expand Down
16 changes: 16 additions & 0 deletions pkg/store/memory/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,22 @@ func TestStore_GetFinding(t *testing.T) {
assert.Nil(t, notFound)
}

func TestStore_ClearFindings(t *testing.T) {
ctx := context.Background()
s := NewStore()

require.NoError(t, s.SaveFindings(ctx, []*types.Finding{
{ResourceID: "1", Status: types.StatusRed},
{ResourceID: "2", Status: types.StatusGreen},
}))

require.NoError(t, s.ClearFindings(ctx))

results, err := s.ListFindings(ctx, store.FindingFilters{})
require.NoError(t, err)
assert.Empty(t, results)
}

func TestStore_ListFindings_NoFilters(t *testing.T) {
ctx := context.Background()
s := NewStore()
Expand Down
3 changes: 3 additions & 0 deletions pkg/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (

// Store defines the interface for persisting and retrieving findings
type Store interface {
// ClearFindings removes all findings from the store.
ClearFindings(ctx context.Context) error

// SaveFindings saves or updates findings
SaveFindings(ctx context.Context, findings []*types.Finding) error

Expand Down
12 changes: 12 additions & 0 deletions pkg/workflow/orchestrator/activities.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
// Activity names
const (
CreateSnapshotActivityName = "version-guard.CreateSnapshot"
ClearFindingsActivityName = "version-guard.ClearFindings"
RecordResourceScanResultActivityName = "version-guard.RecordResourceScanResult"
)

Expand Down Expand Up @@ -66,6 +67,12 @@ func NewActivities(
}
}

// ClearFindings removes any findings left in the worker-local scratch store.
func (a *Activities) ClearFindings(ctx context.Context) error {
activity.GetLogger(ctx).Info("Clearing in-memory findings")
return a.Store.ClearFindings(ctx)
}

// CreateSnapshot reads findings directly from the store and persists a snapshot to S3.
// This avoids passing large finding payloads through Temporal activity results,
// which would exceed the 4MB gRPC message limit for large inventories (12K+ resources).
Expand Down Expand Up @@ -122,6 +129,11 @@ func (a *Activities) CreateSnapshot(ctx context.Context, input CreateSnapshotInp
if err != nil {
return nil, err
}
if err := a.Store.ClearFindings(ctx); err != nil {
logger.Warn("Failed to clear in-memory findings after snapshot persisted", "error", err)
} else {
logger.Info("Cleared in-memory findings after snapshot persisted", "snapshotID", snap.SnapshotID)
}

logger.Info("Snapshot created and persisted",
"snapshotID", snap.SnapshotID,
Expand Down
7 changes: 7 additions & 0 deletions pkg/workflow/orchestrator/activities_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"go.temporal.io/sdk/testsuite"

"github.com/block/Version-Guard/pkg/snapshot"
"github.com/block/Version-Guard/pkg/store"
"github.com/block/Version-Guard/pkg/store/memory"
"github.com/block/Version-Guard/pkg/types"
)
Expand Down Expand Up @@ -123,6 +124,9 @@ func TestActivities_CreateSnapshot_HappyPath(t *testing.T) {
assert.Equal(t, "v4", fakeSnap.saved.Version)
assert.Equal(t, int64(60), fakeSnap.saved.ScanDurationSec)
assert.Equal(t, 3, fakeSnap.saved.Summary.TotalResources)
findings, err := st.ListFindings(context.Background(), store.FindingFilters{})
require.NoError(t, err)
assert.Empty(t, findings, "CreateSnapshot should clear the in-memory scratch store after persisting")
}

func TestActivities_CreateSnapshot_PersistFailureReturnsError(t *testing.T) {
Expand All @@ -144,6 +148,9 @@ func TestActivities_CreateSnapshot_PersistFailureReturnsError(t *testing.T) {
})
require.Error(t, err)
assert.Contains(t, err.Error(), "s3 went down")
finding, getErr := st.GetFinding(context.Background(), "r1")
require.NoError(t, getErr)
assert.NotNil(t, finding, "failed snapshots should leave findings available for activity retry")
}

func TestActivities_CreateSnapshot_EmptyFindings(t *testing.T) {
Expand Down
6 changes: 6 additions & 0 deletions pkg/workflow/orchestrator/orchestrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ func newOrchestratorEnv(t *testing.T) *testsuite.TestWorkflowEnvironment {
suite := &testsuite.WorkflowTestSuite{}
env := suite.NewTestWorkflowEnvironment()
env.RegisterWorkflow(OrchestratorWorkflow)
env.RegisterActivityWithOptions(
func(_ context.Context) error {
return nil
},
activity.RegisterOptions{Name: ClearFindingsActivityName},
)
env.RegisterActivityWithOptions(
func(_ context.Context, _ RecordResourceScanResultInput) error {
return nil
Expand Down
Loading
Loading