sgx-aa/main.go

474 lines
11 KiB
Go

package main
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/xml"
"fmt"
"log"
"net"
"net/http"
"os"
"regexp"
"strings"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/redis/go-redis/v9"
"mellium.im/xmlstream"
"mellium.im/xmpp"
"mellium.im/xmpp/component"
"mellium.im/xmpp/form"
"mellium.im/xmpp/jid"
"mellium.im/xmpp/stanza"
)
// User holds the A&A credentials a JID registered with: the phone number
// (validated as '+' followed by digits) and its outgoing password.
type User struct {
	tel      string // phone number in full international format
	password string // outgoing password for that number
}

// Token returns the hex-encoded SHA-256 digest of the password immediately
// followed by the phone number. It is the secret {token} path segment of the
// inbound SMS callback URL advertised in the registration form.
func (u *User) Token() string {
	digest := sha256.Sum256([]byte(u.password + u.tel))
	return hex.EncodeToString(digest[:])
}
// Message is an outbound SMS handed to the provider.
type Message struct {
	// ID is the XMPP stanza id of the message being relayed.
	ID string
	// Dest is the destination number (the localpart of the recipient JID).
	Dest string
	// Body is the text to send.
	Body string
}
// Process-wide state. Every variable is assigned once in main and then shared
// by the HTTP handlers and the XMPP stanza handlers.
var (
	// urlPrefix is the base URL shown in the registration instructions and
	// passed to the provider's report handler option.
	urlPrefix string
	// componentDomain is this component's domain; SMS peers appear as
	// <number>@componentDomain/sms.
	componentDomain string
	// rclient stores registrations as sgx_aa_creds-<bare JID> and
	// sgx_aa_inbound-<tel> lists.
	rclient *redis.Client
	// xmppSession is the component connection to the XMPP server.
	xmppSession *xmpp.Session
	// provider delivers outbound SMS.
	provider Provider
)
// healthcheckHandler reports that the HTTP server is up by writing "OK"
// followed by a newline. The request itself is not inspected.
func healthcheckHandler(w http.ResponseWriter, _ *http.Request) {
	fmt.Fprint(w, "OK\n")
}
// srrHandler handles SMS status report callbacks posted to /{token}/sms/srr.
//
// Form fields:
//   - stanzaId (body, exactly one): id of the XMPP message the report is about.
//   - oa (query string or body): the registered number; its
//     sgx_aa_inbound-<oa> entry must exist and its stored token must equal
//     the {token} path segment.
//   - da (body, exactly one): the other party's number, used as the localpart
//     of the sender JID <da>@componentDomain/sms.
//   - code (optional): "1" sends an XEP-0184 <received/> receipt for
//     stanzaId; "2" or "16" bounces stanzaId as a cancel/internal-server-error
//     "Rejected by SMSC"; any other value (or none) sends nothing over XMPP.
//
// Responds 400 for malformed fields, 404 when the lookup or token check fails,
// 500 when the stored JID is invalid or the XMPP send fails, and "OK: done"
// otherwise.
func srrHandler(w http.ResponseWriter, r *http.Request) {
	// Required fields are validated one by one below, so a parse error is only
	// logged: a malformed unrelated pair must not drop an otherwise valid report.
	if err := r.ParseForm(); err != nil {
		log.Printf("srr ParseForm: %s\n", err)
	}

	stanzaID, ok := r.PostForm["stanzaId"]
	if !ok || len(stanzaID) != 1 {
		http.Error(w, "ERR: bad stanzaId", http.StatusBadRequest)
		return
	}
	// oa is read from r.Form (query string or body), unlike the other fields.
	oa, ok := r.Form["oa"]
	if !ok {
		http.Error(w, "ERR: oa not present", http.StatusBadRequest)
		return
	}
	inbound, err := rclient.LRange(
		context.Background(),
		"sgx_aa_inbound-"+oa[0],
		0,
		1,
	).Result()
	if err != nil || len(inbound) < 2 || inbound[1] != mux.Vars(r)["token"] {
		http.Error(w, "ERR: oa lookup failed", http.StatusNotFound)
		return
	}
	to, err := jid.Parse(inbound[0])
	if err != nil {
		log.Printf("Invalid JID: %s %s\n", inbound[0], err)
		http.Error(w, "ERR: target invalid", http.StatusInternalServerError)
		return
	}
	da, ok := r.PostForm["da"]
	if !ok || len(da) != 1 {
		http.Error(w, "ERR: bad da", http.StatusBadRequest)
		return
	}
	from, err := jid.New(da[0], componentDomain, "sms")
	if err != nil {
		log.Printf("Invalid From: %s %s\n", da[0], err)
		http.Error(w, "ERR: invalid da", http.StatusBadRequest)
		return
	}

	var status string
	if code, ok := r.Form["code"]; ok {
		status = code[0]
	}

	var report xml.TokenReader
	switch status {
	case "1":
		// Received: XEP-0184 receipt referencing the original stanza id.
		report = stanza.Message{
			Type: stanza.ChatMessage,
			ID:   uuid.New().String(),
			To:   to,
			From: from,
		}.Wrap(xmlstream.Wrap(
			nil,
			xml.StartElement{
				Name: xml.Name{Space: "urn:xmpp:receipts", Local: "received"},
				Attr: []xml.Attr{{Name: xml.Name{Local: "id"}, Value: stanzaID[0]}},
			},
		))
	case "2", "16":
		// Rejected: bounce the original stanza id back as an error.
		report = stanza.Message{
			Type: stanza.ErrorMessage,
			ID:   stanzaID[0],
			To:   to,
			From: from,
		}.Wrap(stanza.Error{
			Type:      stanza.Cancel,
			Condition: stanza.InternalServerError,
			Text:      map[string]string{"en": "Rejected by SMSC"},
		}.TokenReader())
	}

	if report != nil {
		if err := xmppSession.Send(context.Background(), report); err != nil {
			log.Printf("Send failed: %s %s\n", to, err)
			http.Error(w, "ERR: send failed", http.StatusInternalServerError)
			return
		}
	}
	fmt.Fprintln(w, "OK: done")
}
// inboundSMSHandler relays an incoming SMS posted to /{token}/sms to the JID
// registered for the destination number.
//
// Form fields (request body, exactly one each):
//   - da: the registered number; sgx_aa_inbound-<da> must exist and its
//     stored token must equal the {token} path segment.
//   - oa: the sender's number, used as the localpart of the sender JID
//     <oa>@componentDomain/sms.
//   - ud: the text, delivered as the <body/> of a chat message.
//
// Responds 400 for malformed fields, 404 when the lookup or token check fails,
// 500 when the stored JID is invalid or the XMPP send fails, and "OK: done"
// otherwise.
func inboundSMSHandler(w http.ResponseWriter, r *http.Request) {
	// Required fields are validated one by one below, so a parse error is only
	// logged: a malformed unrelated pair must not drop an otherwise valid SMS.
	if err := r.ParseForm(); err != nil {
		log.Printf("sms ParseForm: %s\n", err)
	}

	tel, ok := r.PostForm["da"]
	if !ok || len(tel) != 1 {
		http.Error(w, "ERR: bad da", http.StatusBadRequest)
		return
	}
	inbound, err := rclient.LRange(
		context.Background(),
		"sgx_aa_inbound-"+tel[0],
		0,
		1,
	).Result()
	if err != nil || len(inbound) < 2 || inbound[1] != mux.Vars(r)["token"] {
		http.Error(w, "ERR: target not found", http.StatusNotFound)
		return
	}
	to, err := jid.Parse(inbound[0])
	if err != nil {
		log.Printf("Invalid JID: %s %s\n", inbound[0], err)
		http.Error(w, "ERR: target invalid", http.StatusInternalServerError)
		return
	}
	sender, ok := r.PostForm["oa"]
	if !ok || len(sender) != 1 {
		http.Error(w, "ERR: bad oa", http.StatusBadRequest)
		return
	}
	from, err := jid.New(sender[0], componentDomain, "sms")
	if err != nil {
		log.Printf("Invalid From: %s %s\n", sender[0], err)
		http.Error(w, "ERR: invalid oa", http.StatusBadRequest)
		return
	}
	body, ok := r.PostForm["ud"]
	if !ok || len(body) != 1 {
		http.Error(w, "ERR: bad ud", http.StatusBadRequest)
		return
	}

	err = xmppSession.Send(context.Background(), stanza.Message{
		Type: stanza.ChatMessage,
		ID:   uuid.New().String(),
		To:   to,
		From: from,
	}.Wrap(xmlstream.Wrap(
		xmlstream.Token(xml.CharData(body[0])),
		xml.StartElement{Name: xml.Name{Local: "body"}},
	)))
	if err != nil {
		log.Printf("Send failed: %s %s\n", to, err)
		http.Error(w, "ERR: send failed", http.StatusInternalServerError)
		return
	}
	fmt.Fprintln(w, "OK: done")
}
// MessageBody is a <message/> stanza decoded together with the character
// data of its <body/> child (empty when the message has no body).
type MessageBody struct {
	stanza.Message

	// Body is the text of the <body/> element.
	Body string `xml:"body"`
}
// handleMessage decodes an incoming <message/> and relays its body as an SMS
// through provider. It uses the credentials stored under
// sgx_aa_creds-<sender bare JID>, and the recipient JID's localpart is the
// destination number.
//
// Messages that fail to decode, have no body, or are error stanzas are
// ignored. If the sender has no stored credentials, the credential lookup
// fails, or the provider returns an error, an error message reusing the
// original stanza id is sent back to the sender. It always returns nil.
func handleMessage(d *xml.Decoder, start *xml.StartElement) error {
	msg := MessageBody{}
	if err := d.DecodeElement(&msg, start); err != nil || msg.Body == "" {
		return nil
	}
	if msg.Type == stanza.ErrorMessage {
		// RFC 6120 forbids answering an error stanza with another error, and a
		// bounce may echo the original <body/>, which must not be re-sent as SMS.
		return nil
	}

	ctx := context.Background()
	// bounce reports a failure back to the sender as an error message.
	bounce := func(cond stanza.Condition, text string) {
		err := xmppSession.Send(ctx, stanza.Message{
			Type: stanza.ErrorMessage,
			ID:   msg.ID,
			To:   msg.From,
			From: msg.To,
		}.Wrap(stanza.Error{
			Type:      stanza.Cancel,
			Condition: cond,
			Text:      map[string]string{"en": text},
		}.TokenReader()))
		if err != nil {
			log.Printf("Send failed: %s %s\n", msg.From, err)
		}
	}

	creds, err := rclient.LRange(
		ctx,
		"sgx_aa_creds-"+msg.From.Bare().String(),
		0,
		1,
	).Result()
	switch {
	case err != nil:
		// A storage failure is not the user's fault; don't claim they are unregistered.
		log.Printf("Creds lookup failed for %s: %s\n", msg.From, err)
		bounce(stanza.InternalServerError, "Registration lookup failed")
		return nil
	case len(creds) < 2:
		log.Printf("No creds found for %s\n", msg.From)
		bounce(stanza.Forbidden, "Not registered")
		return nil
	}

	err = provider.SendMessage(
		User{creds[0], creds[1]},
		Message{
			ID:   msg.ID,
			Body: msg.Body,
			Dest: msg.To.Localpart(),
		},
	)
	if err != nil {
		log.Printf("SendMessage error: %s\n", err)
		bounce(stanza.InternalServerError, err.Error())
	}
	return nil
}
// RegQuery is the jabber:iq:register <query/> payload. Only its embedded
// jabber:x:data form is decoded; Form is nil when no form is present.
type RegQuery struct {
	// Form is the submitted (or requested) registration data form.
	Form *form.Data `xml:"jabber:x:data x,omitempty"`
}
// IQ is an incoming <iq/> stanza together with its optional registration
// query.
type IQ struct {
	stanza.IQ

	// RegQuery is the decoded jabber:iq:register query, or nil when the IQ
	// carries no such payload.
	RegQuery *RegQuery `xml:"jabber:iq:register query,omitempty"`
}
// sendRegistrationForm answers a jabber:iq:register get with a data form that
// asks for the phone number ("tel") and its outgoing password ("password").
// The instructions show the callback URL the user must configure at A&A. The
// error from Send is not checked.
func sendRegistrationForm(iq *IQ) {
	instructions := fmt.Sprintf("Set your number to deliver SMS to %s/<sha256 of password+tel>/sms", urlPrefix)

	regForm := form.New(
		form.Title("A&A Setup"),
		form.Instructions(instructions),
		form.Hidden("FORM_TYPE", form.Value("jabber:iq:register")),
		form.Text("tel", form.Label("Phone number"), form.Desc("This is the phone number as shown on the control pages for your VoIP number in full international format with no spaces.")),
		form.TextPrivate("password", form.Label("Password"), form.Desc("The corresponding outgoing password as set in the control pages for your VoIP number.")),
	)

	query := xml.StartElement{Name: xml.Name{Space: "jabber:iq:register", Local: "query"}}
	payload := xmlstream.Wrap(regForm.TokenReader(), query)

	xmppSession.Send(context.Background(), iq.Result(payload))
}
// registrationTelPattern accepts a phone number in full international format:
// a leading '+' followed by one or more digits.
var registrationTelPattern = regexp.MustCompile(`^\+\d+$`)

// handleRegistrationForm processes a submitted jabber:iq:register data form.
// It validates the "tel" and "password" fields and then, in one Redis
// transaction, stores:
//   - sgx_aa_inbound-<tel>      = [bare JID, User.Token()] (routes inbound SMS/reports)
//   - sgx_aa_creds-<bare JID>   = [tel, password]          (credentials for outbound SMS)
//
// It also deletes the inbound mapping of any number this JID previously
// registered.
//
// The caller must ensure iq.RegQuery and iq.RegQuery.Form are non-nil. The IQ
// is answered with:
//   - an empty result on success, or when the same JID re-submits identical
//     credentials;
//   - bad-request for missing or invalid fields;
//   - conflict when the number is registered to a different JID;
//   - internal-server-error on Redis failures.
func handleRegistrationForm(iq *IQ) {
	ctx := context.Background()
	reply := func(r xml.TokenReader) {
		if err := xmppSession.Send(ctx, r); err != nil {
			log.Printf("Send failed: %s %s\n", iq.From, err)
		}
	}
	replyError := func(cond stanza.Condition, text string) {
		reply(iq.Error(stanza.Error{
			Type:      stanza.Cancel,
			Condition: cond,
			Text:      map[string]string{"en": text},
		}))
	}

	tel, ok := iq.RegQuery.Form.GetString("tel")
	if !ok {
		replyError(stanza.BadRequest, "No tel found")
		return
	}
	if !registrationTelPattern.MatchString(tel) {
		replyError(stanza.BadRequest, "Invalid tel")
		return
	}
	password, ok := iq.RegQuery.Form.GetString("password")
	if !ok {
		replyError(stanza.BadRequest, "No password found")
		return
	}

	user := User{tel, password}
	token := user.Token()
	bareJID := iq.From.Bare().String()

	// Inbound mappings are keyed by phone number (see the writes below and the
	// HTTP handlers), so an existing owner of this number is found under tel.
	oldInbound, err := rclient.LRange(ctx, "sgx_aa_inbound-"+tel, 0, 1).Result()
	if err != nil {
		replyError(stanza.InternalServerError, err.Error())
		return
	}
	if len(oldInbound) > 1 {
		if oldInbound[0] != bareJID {
			replyError(stanza.Conflict, "Another user exists for "+tel)
			return
		}
		if oldInbound[1] == token {
			// Same JID, same credentials: nothing to change.
			reply(iq.Result(nil))
			return
		}
	}

	oldCreds, err := rclient.LRange(ctx, "sgx_aa_creds-"+bareJID, 0, 1).Result()
	if err != nil {
		replyError(stanza.InternalServerError, err.Error())
		return
	}

	// NOTE(review): the ownership check above and this write are not atomic. Two
	// JIDs registering the same number concurrently could both pass the check.
	// A WATCH on sgx_aa_inbound-<tel> would close that window.
	_, err = rclient.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
		if len(oldCreds) > 1 {
			// Stop routing the previously registered number to this JID.
			pipe.Del(ctx, "sgx_aa_inbound-"+oldCreds[0])
		}
		pipe.Del(ctx, "sgx_aa_inbound-"+tel)
		pipe.RPush(ctx, "sgx_aa_inbound-"+tel, bareJID, token)
		pipe.Del(ctx, "sgx_aa_creds-"+bareJID)
		pipe.RPush(ctx, "sgx_aa_creds-"+bareJID, tel, password)
		return nil
	})
	if err != nil {
		replyError(stanza.InternalServerError, err.Error())
		return
	}
	reply(iq.Result(nil))
}
// main starts the gateway. It expects six positional arguments:
//
//	1: component domain (this component's JID)
//	2: component shared secret
//	3: XMPP server address (host:port) to dial over TCP
//	4: network for the HTTP listener, passed to net.Listen
//	5: address for the HTTP listener, passed to net.Listen
//	6: URL prefix shown in the registration form and given to the provider
//
// REDIS_URL in the environment selects the Redis server. Any setup failure
// or XMPP session error terminates the process.
func main() {
	if len(os.Args) < 7 {
		log.Fatalf("Usage: %s <component-domain> <component-secret> <xmpp-host:port> <listen-network> <listen-address> <url-prefix>", os.Args[0])
	}

	urlPrefix = os.Args[6]
	router := mux.NewRouter()
	router.HandleFunc("/healthcheck", healthcheckHandler)
	router.HandleFunc("/{token}/sms/srr", srrHandler)
	router.HandleFunc("/{token}/sms", inboundSMSHandler)

	opt, err := redis.ParseURL(os.Getenv("REDIS_URL"))
	if err != nil {
		log.Fatal(err)
	}
	rclient = redis.NewClient(opt)

	componentDomain = os.Args[1]
	componentJID := jid.MustParse(componentDomain)
	addr := os.Args[3]
	conn, err := net.Dial("tcp", addr)
	if err != nil {
		log.Fatalf("Failed XMPP dial: %s", err)
	}
	xmppSession, err = xmpp.NewSession(
		context.Background(),
		jid.MustParse(strings.Split(addr, ":")[0]),
		componentJID,
		conn,
		xmpp.Secure,
		component.Negotiator(componentJID, []byte(os.Args[2]), false),
	)
	if err != nil {
		log.Fatalf("Failed setting up XMPP Session: %s", err)
	}

	provider, err = NewAndrewsArnoldProvider(WithReportHandler(urlPrefix))
	if err != nil {
		log.Fatalf("Failed initializing provider: %s", err)
	}

	go func() {
		err := xmppSession.Serve(xmpp.HandlerFunc(func(t xmlstream.TokenReadEncoder, start *xml.StartElement) error {
			// Put the start element back in front of the stream, then consume it,
			// so DecodeElement can be handed start directly.
			d := xml.NewTokenDecoder(xmlstream.MultiReader(xmlstream.Token(*start), t))
			if _, err := d.Token(); err != nil {
				return err
			}

			switch start.Name.Local {
			case "message":
				return handleMessage(d, start)
			case "iq":
				iq := IQ{}
				if err := d.DecodeElement(&iq, start); err != nil {
					// Not a valid iq?
					return nil
				}
				switch {
				case iq.Type == stanza.ResultIQ || iq.Type == stanza.ErrorIQ:
					// Don't reply to these.
				case iq.Type == stanza.GetIQ && iq.RegQuery != nil:
					sendRegistrationForm(&iq)
				case iq.Type == stanza.SetIQ && iq.RegQuery != nil && iq.RegQuery.Form != nil:
					handleRegistrationForm(&iq)
				default:
					xmppSession.Send(context.Background(), iq.Error(stanza.Error{
						Type:      stanza.Cancel,
						Condition: stanza.ServiceUnavailable,
					}))
				}
				return nil
			}
			return nil
		}))
		log.Fatalf("XMPP session error: %s\n", err)
	}()

	listener, err := net.Listen(os.Args[4], os.Args[5])
	if err != nil {
		log.Fatal(err)
	}
	log.Fatal(http.Serve(listener, router))
}