Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 22 additions & 175 deletions ssh/server/channels/session.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package channels

import (
"io"
"strings"
"sync"

gliderssh "github.com/gliderlabs/ssh"
"github.com/shellhub-io/shellhub/ssh/session"
Expand Down Expand Up @@ -74,17 +72,13 @@ const AuthRequestOpenSSHRequest = "[email protected]"
// https://www.ietf.org/archive/id/draft-miller-ssh-agent-11.html#section-4.2
const AuthRequestOpenSSHChannel = "[email protected]"

type DefaultSessionHandlerOptions struct {
RecordURL string
}

// DefaultSessionHandler is the default handler for session's channel.
//
// A session is a remote execution of a program. The program may be a shell, an application, a system command, or some
// built-in subsystem. It may or may not have a tty, and may or may not involve X11 forwarding.
//
// https://www.rfc-editor.org/rfc/rfc4254#section-6
func DefaultSessionHandler(opts DefaultSessionHandlerOptions) gliderssh.ChannelHandler {
func DefaultSessionHandler() gliderssh.ChannelHandler {
return func(_ *gliderssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx gliderssh.Context) {
sess, _ := session.ObtainSession(ctx)

Expand Down Expand Up @@ -132,8 +126,9 @@ func DefaultSessionHandler(opts DefaultSessionHandlerOptions) gliderssh.ChannelH

defer agent.Close()

var wg sync.WaitGroup
go pipe(ctx, sess, client, agent)

// TODO: Add middleware to block a certain type of requests.
for {
select {
case <-ctx.Done():
Expand All @@ -155,19 +150,12 @@ func DefaultSessionHandler(opts DefaultSessionHandlerOptions) gliderssh.ChannelH
// always keeping the prefix "keepalive". So, to maintain the retro compatibility, we check if this
// prefix exists and perform the necessary operations.
case strings.HasPrefix(req.Type, KeepAliveRequestTypePrefix):
wantReply, err := client.SendRequest(KeepAliveRequestType, req.WantReply, req.Payload)
if err != nil {
if _, err := client.SendRequest(KeepAliveRequestType, req.WantReply, req.Payload); err != nil {
logger.Error("failed to send the keepalive request received from agent to client")

return
}

if err := req.Reply(wantReply, nil); err != nil {
logger.WithError(err).Error("failed to send the keepalive response back to agent")

return
}

if err := sess.KeepAlive(); err != nil {
logger.WithError(err).Error("failed to send the API request to inform that the session is open")

Expand All @@ -180,68 +168,35 @@ func DefaultSessionHandler(opts DefaultSessionHandlerOptions) gliderssh.ChannelH
}
}
}
case req, ok := <-clientReqs:
case req, ok := <-agentReqs:
if !ok {
logger.Trace("client requests is closed")
logger.Trace("agent requests is closed")

return
}

logger.Debugf("request from client to agent: %s", req.Type)
logger.Debugf("request from agent to client: %s", req.Type)

ok, err := agent.SendRequest(req.Type, req.WantReply, req.Payload)
ok, err := client.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
logger.WithError(err).Error("failed to send the request from client to agent")
logger.WithError(err).Error("failed to send the request from agent to client")

continue
}

switch req.Type {
case ShellRequestType, ExecRequestType, SubsystemRequestType:
// Once the session has been set up, a program is started at the remote end. The program can be a
// shell, an application program, or a subsystem with a host-independent name. **Only one of these
// requests can succeed per channel.**
//
// https://www.rfc-editor.org/rfc/rfc4254#section-6.5
if sess.Handled && req.Type == ShellRequestType {
logger.Warn("fail to start a new session before ending the previous one")

if err := req.Reply(false, nil); err != nil {
logger.WithError(err).Error("failed to reply the client when data pipe already started")
}

continue
}

if req.WantReply {
if err := req.Reply(ok, nil); err != nil {
logger.WithError(err).Error("failed to reply the client with right response for pipe request type")

return
logger.WithError(err).Error(err)
}
}
case req, ok := <-clientReqs:
if !ok {
logger.Trace("client requests is closed")

logger.Info("session type set")

if req.Type == ShellRequestType && sess.Pty.Term != "" {
if err := sess.Announce(client); err != nil {
logger.WithError(err).Warn("failed to get the namespace announcement")
}
}
return
}

// The server SHOULD NOT halt the execution of the protocol stack when starting a shell or a
// program. All input and output from these SHOULD be redirected to the channel or to the
// encrypted tunnel.
//
// https://www.rfc-editor.org/rfc/rfc4254#section-6.5
wg.Add(1)
go func() {
ch := make(chan bool)
go func() {
<-ch
wg.Done()
}()

pipe(ctx, sess, client, agent, req.Type, opts, ch)
}()
switch req.Type {
case PtyRequestType:
var pty session.Pty

Expand All @@ -251,14 +206,6 @@ func DefaultSessionHandler(opts DefaultSessionHandlerOptions) gliderssh.ChannelH

sess.Pty = pty

if req.WantReply {
// req.Reply(ok, nil) //nolint:errcheck
if err := req.Reply(ok, nil); err != nil {
logger.WithError(err).Error("failed to reply for pty-req")

return
}
}
case WindowChangeRequestType:
var dimensions session.Dimensions

Expand All @@ -268,120 +215,20 @@ func DefaultSessionHandler(opts DefaultSessionHandlerOptions) gliderssh.ChannelH

sess.Pty.Columns = dimensions.Columns
sess.Pty.Rows = dimensions.Rows

if req.WantReply {
if err := req.Reply(ok, nil); err != nil {
logger.Error("failed to reply for window-change")

return
}
}
case AuthRequestOpenSSHRequest:
_, err := agent.SendRequest(AuthRequestOpenSSHRequest, req.WantReply, req.Payload)
if err != nil {
reject(nil, "failed to the auth request to agent")

return
}

req.Reply(true, nil) //nolint:errcheck

gliderssh.SetAgentRequested(ctx)

go func() {
clientConn := ctx.Value(gliderssh.ContextKeyConn).(gossh.Conn)
agentChannels := sess.AgentClient.HandleChannelOpen(AuthRequestOpenSSHChannel)

for {
newAgentChannel, ok := <-agentChannels
if !ok {
reject(nil, "channel for agent forwarding done")

return
}

agentChannel, reqs, err := newAgentChannel.Accept()
if err != nil {
reject(nil, "failed to accept the chanel request from agent on auth request")

return
}

defer agentChannel.Close()
go gossh.DiscardRequests(reqs)

go func() {
clientChannel, reqs, err := clientConn.OpenChannel(AuthRequestOpenSSHChannel, nil)
if err != nil {
reject(nil, "failed to open the auth request channel from agent to client")

return
}

defer clientChannel.Close()
go gossh.DiscardRequests(reqs)

var wg sync.WaitGroup

wg.Add(1)
go func() {
defer agentChannel.CloseWrite() //nolint:errcheck
defer wg.Done()

if _, err := io.Copy(agentChannel, clientChannel); err != nil && err != io.EOF {
logger.WithError(err).Trace("auth agent forwarding coping from client to agent")
}
}()

wg.Add(1)
go func() {
defer clientChannel.CloseWrite() //nolint:errcheck
defer wg.Done()

if _, err := io.Copy(clientChannel, agentChannel); err != nil && err != io.EOF {
logger.WithError(err).Trace("auth agent forwarding coping from agent to client")
}
}()

wg.Wait()
}()
logger.WithError(err).Trace("auth request channel piping done")
}
}()
default:
if req.WantReply {
if err := req.Reply(ok, nil); err != nil {
logger.WithError(err).Error("failed to reply")

return
}
}
}
case req, ok := <-agentReqs:
if !ok {
logger.Trace("agent requests is closed")

return
}

logger.Debugf("request from agent to client: %s", req.Type)

if req.Type == ExitStatusRequest {
wg.Wait()
}
logger.Debugf("request from client to agent: %s", req.Type)

ok, err := client.SendRequest(req.Type, req.WantReply, req.Payload)
ok, err := agent.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
logger.WithError(err).Error("failed to send the request from agent to client")
logger.WithError(err).Error("failed to send the request from client to agent")

continue
}

if req.WantReply {
if err := req.Reply(ok, nil); err != nil {
logger.WithError(err).Error("failed to reply the agent request")

return
logger.WithError(err).Error(err)
}
}
}
Expand Down
53 changes: 30 additions & 23 deletions ssh/server/channels/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
gossh "golang.org/x/crypto/ssh"
)

func pipe(ctx gliderssh.Context, sess *session.Session, client gossh.Channel, agent gossh.Channel, req string, opts DefaultSessionHandlerOptions, ch chan bool) {
func pipe(ctx gliderssh.Context, sess *session.Session, client gossh.Channel, agent gossh.Channel) {
defer func() {
ctx.Lock()
sess.Handled = false
Expand All @@ -30,10 +30,6 @@ func pipe(ctx gliderssh.Context, sess *session.Session, client gossh.Channel, ag
WithFields(log.Fields{"session": sess.UID, "sshid": sess.SSHID}).
Trace("data pipe between client and agent has done")

if err := sess.Type(req); err != nil {
log.WithError(err).Warn("failed to set the session type")
}

wg := new(sync.WaitGroup)
wg.Add(2)

Expand All @@ -43,11 +39,26 @@ func pipe(ctx gliderssh.Context, sess *session.Session, client gossh.Channel, ag
go func() {
defer wg.Done()
defer client.CloseWrite() //nolint:errcheck
defer func() {
ch <- true
}()

if req == ShellRequestType {
// NOTE: As the copy required to record the session seem to be inefficient, if we don't have a record URL
// defined, we use an [io.Copy] for the data piping between agent and client.
recordURL := ctx.Value("RECORD_URL").(string)
if (envs.IsEnterprise() || envs.IsCloud()) && recordURL != "" {
// TODO: Should it be a channel of pointers to [models.SessionRecorded], or just the structure, could deliver a
// better performance?
camera := make(chan *models.SessionRecorded, 100)

go func() {
for {
frame, ok := <-camera
if !ok {
break
}

sess.Record(frame, recordURL) //nolint:errcheck
}
}()

buffer := make([]byte, 1024)
for {
read, err := a.Read(buffer)
Expand Down Expand Up @@ -75,25 +86,21 @@ func pipe(ctx gliderssh.Context, sess *session.Session, client gossh.Channel, ag
break
}

if envs.IsEnterprise() || envs.IsCloud() {
message := string(buffer[:read])

sess.Record(&models.SessionRecorded{ //nolint:errcheck
UID: sess.UID,
Namespace: sess.Lookup["domain"],
Message: message,
Width: int(sess.Pty.Columns),
Height: int(sess.Pty.Rows),
}, opts.RecordURL)
camera <- &models.SessionRecorded{ //nolint:errcheck
UID: sess.UID,
Namespace: sess.Lookup["domain"],
Message: string(buffer[:read]),
Width: int(sess.Pty.Columns),
Height: int(sess.Pty.Rows),
}
}
} else {
if _, err := io.Copy(client, a); err != nil && err != io.EOF {
log.WithError(err).Error("failed on coping data from agent to client")
log.WithError(err).Error("failed on coping data from client to agent")
}

log.Trace("agent channel data copy done")
}

log.Trace("agent channel data copy done")
}()

go func() {
Expand All @@ -103,7 +110,7 @@ func pipe(ctx gliderssh.Context, sess *session.Session, client gossh.Channel, ag
// connection to avoid it be hanged after data flow ends.
if ver, err := semver.NewVersion(sess.Device.Info.Version); ver != nil && err == nil {
// NOTE: We indicate here v0.9.3, but it is not included due the assertion `less than`.
if ver.LessThan(semver.MustParse("v0.9.3")) && req == ExecRequestType {
if ver.LessThan(semver.MustParse("v0.9.3")) {
agent.Close()
} else {
agent.CloseWrite() //nolint:errcheck
Expand Down
7 changes: 2 additions & 5 deletions ssh/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func NewServer(opts *Options, tunnel *httptunnel.Tunnel, cache cache.Cache) *Ser
Addr: ":2222",
ConnCallback: func(ctx gliderssh.Context, conn net.Conn) net.Conn {
ctx.SetValue("conn", conn)
ctx.SetValue("RECORD_URL", opts.RecordURL)

return conn
},
Expand Down Expand Up @@ -88,11 +89,7 @@ func NewServer(opts *Options, tunnel *httptunnel.Tunnel, cache cache.Cache) *Ser
// and the server. SSH channels serve as the infrastructure for executing commands, establishing shell sessions,
// and securely forwarding network services.
ChannelHandlers: map[string]gliderssh.ChannelHandler{
channels.SessionChannel: channels.DefaultSessionHandler(
channels.DefaultSessionHandlerOptions{
RecordURL: opts.RecordURL,
},
),
channels.SessionChannel: channels.DefaultSessionHandler(),
channels.DirectTCPIPChannel: channels.DefaultDirectTCPIPHandler,
},
LocalPortForwardingCallback: func(ctx gliderssh.Context, dhost string, dport uint32) bool {
Expand Down
Loading