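Earlier we discussed gRPC's SSL/TLS security mechanism and found the setup fairly involved: certificate signing, for example, has to be configured on both the server and the client. On reflection, JWT is more convenient and more capable: besides its cryptographic signature, a JWT can carry user information between server and client (note that a signed JWT's payload is only Base64url-encoded, not encrypted, so it is readable by anyone who sees the token). Most fundamentally, of course, verifying the JWT lets the server control which functions a client is permitted to use.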
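The principle behind using JWT to control access to gRPC service functions is simple: the client first authenticates with the server to obtain a JWT, then sends that JWT along with each service call so the server can check its permissions. Submitting credentials and getting a JWT back can be implemented as a separate service function, such as GetAuthToken in the .proto file below: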
message PBPOSCredential {
string userid = 1;
string password = 2;
}
message PBPOSToken {
string jwt = 1;
}
service SendCommand {
rpc SingleResponse(PBPOSCommand) returns (PBPOSResponse) {};
rpc GetTxnItems(PBPOSCommand) returns (stream PBTxnItem) {};
rpc GetAuthToken(PBPOSCredential) returns (PBPOSToken) {};
}
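To make the two-step flow concrete without gRPC or jwt-scala, here is a stdlib-only Scala model. It assumes a simplified token of the form `userid.signature` (HMAC-SHA256 over the user id) in place of a real JWT, and models the metadata as a plain Map. The names TokenFlowModel, getAuthToken and singleResponse only mirror the service above; this is not the generated gRPC stub API.

```scala
import java.nio.charset.StandardCharsets.UTF_8
import java.util.Base64
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec

// Hypothetical model of: GetAuthToken issues a signed token, later calls
// carry it in the request metadata under the key "jwt".
object TokenFlowModel {
  type Metadata = Map[String, String]

  private val secret = "OpenSesame".getBytes(UTF_8)
  private val users  = Map("johnny" -> "p4ssw0rd", "tiger" -> "secret")

  // HMAC-SHA256 of the data, Base64url-encoded without padding.
  private def sign(data: String): String = {
    val mac = Mac.getInstance("HmacSHA256")
    mac.init(new SecretKeySpec(secret, "HmacSHA256"))
    Base64.getUrlEncoder.withoutPadding.encodeToString(mac.doFinal(data.getBytes(UTF_8)))
  }

  // Step 1: credentials in, token out (None stands for a failed login).
  def getAuthToken(userid: String, password: String): Option[String] =
    if (users.get(userid).contains(password)) Some(s"$userid.${sign(userid)}") else None

  // Step 2: the server reads the token from metadata and checks the signature.
  def singleResponse(headers: Metadata): Either[String, String] =
    headers.get("jwt") match {
      case Some(token) =>
        token.split('.') match {
          case Array(userid, sig) if sig == sign(userid) => Right(s"hello $userid")
          case _                                          => Left("invalid token")
        }
      case None => Left("missing token")
    }
}
```

Because the server can recompute the signature from the secret alone, it needs no session table: any request carrying a correctly signed token is accepted, which is the property the real JWT gives us.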
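The tricky part is getting the JWT from the client to the server, because gRPC takes over the Request and Response and leaves no obvious place to attach extra data. One way is to use Interceptors on the request headers, i.e. the metadata: the client writes the JWT into the metadata, and the server reads it from there.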
Let's first look at how the client-side Interceptor is set up and used:
class AuthClientInterceptor(jwt: String) extends ClientInterceptor {
def interceptCall[ReqT, RespT](methodDescriptor: MethodDescriptor[ReqT, RespT], callOptions: CallOptions, channel: io.grpc.Channel): ClientCall[ReqT, RespT] =
new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](channel.newCall(methodDescriptor, callOptions)) {
override def start(responseListener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
headers.put(Key.of("jwt", Metadata.ASCII_STRING_MARSHALLER), jwt)
super.start(responseListener, headers)
}
}
}
...
val unsafeChannel = NettyChannelBuilder
.forAddress("192.168.0.189",50051)
.negotiationType(NegotiationType.PLAINTEXT)
.build()
val securedChannel = ClientInterceptors.intercept(unsafeChannel, new AuthClientInterceptor(jwt))
val securedClient = SendCommandGrpc.blockingStub(securedChannel)
val resp = securedClient.singleResponse(PBPOSCommand())
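AuthClientInterceptor is essentially a decorator: ClientInterceptors.intercept wraps the channel so that every call started through the wrapper gets the "jwt" header, while calls on the original channel are untouched. The stdlib-only model below illustrates this, assuming a "channel" can be reduced to a function from (metadata, request) to response; ClientInterceptorModel and its members are hypothetical names, not the grpc-java API.

```scala
// Hypothetical model of ClientInterceptors.intercept(channel, interceptor).
object ClientInterceptorModel {
  type Metadata    = Map[String, String]
  type Channel     = (Metadata, String) => String
  type Interceptor = Channel => Channel

  // Counterpart of AuthClientInterceptor: add the token under "jwt" on every call.
  def authInterceptor(jwt: String): Interceptor =
    next => (headers, request) => next(headers + ("jwt" -> jwt), request)

  // Counterpart of ClientInterceptors.intercept: wrap the base channel.
  def intercept(channel: Channel, interceptors: Interceptor*): Channel =
    interceptors.foldLeft(channel)((ch, i) => i(ch))

  // Fake "server" end that echoes whichever token it received.
  val unsafeChannel: Channel =
    (headers, request) => s"$request with jwt=${headers.getOrElse("jwt", "<none>")}"
}
```

Since wrapping produces a new channel value, the base channel keeps working without the header, which is why the token request itself can still go through unsafeChannel.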
The authentication request, i.e. obtaining the JWT, needs no Interceptor, so it goes through unsafeChannel, which has none:
//build connection channel
val unsafeChannel = NettyChannelBuilder
.forAddress("192.168.0.189",50051)
.negotiationType(NegotiationType.PLAINTEXT)
.build()
val authClient = SendCommandGrpc.blockingStub(unsafeChannel)
val jwt = authClient.getAuthToken(PBPOSCredential(userid="johnny",password="p4ssw0rd")).jwt
println(s"got jwt: $jwt")
Building and using JWTs was covered in several earlier posts; here is the helper used in this demo:
package com.datatech.auth
import pdi.jwt._
import org.json4s.native.Json
import org.json4s._
import org.json4s.jackson.JsonMethods._
import pdi.jwt.algorithms._
import scala.util._
object AuthBase {
type UserInfo = Map[String, Any]
case class AuthBase(
algorithm: JwtAlgorithm = JwtAlgorithm.HMD5,
secret: String = "OpenSesame",
getUserInfo: (String,String) => Option[UserInfo] = null) {
ctx =>
def withAlgorithm(algo: JwtAlgorithm): AuthBase = ctx.copy(algorithm = algo)
def withSecretKey(key: String): AuthBase = ctx.copy(secret = key)
def withUserFunc(f: (String, String) => Option[UserInfo]): AuthBase = ctx.copy(getUserInfo = f)
def authenticateToken(token: String): Option[String] =
algorithm match {
case algo: JwtAsymmetricAlgorithm =>
Jwt.isValid(token, secret, Seq((algorithm.asInstanceOf[JwtAsymmetricAlgorithm]))) match {
case true => Some(token)
case _ => None
}
case _ =>
Jwt.isValid(token, secret, Seq((algorithm.asInstanceOf[JwtHmacAlgorithm]))) match {
case true => Some(token)
case _ => None
}
}
def getUserInfo(token: String): Option[UserInfo] = {
algorithm match {
case algo: JwtAsymmetricAlgorithm =>
Jwt.decodeRawAll(token, secret, Seq(algorithm.asInstanceOf[JwtAsymmetricAlgorithm])) match {
case Success(parts) => Some(((parse(parts._2).asInstanceOf[JObject]) \ "userinfo").values.asInstanceOf[UserInfo])
case Failure(err) => None
}
case _ =>
Jwt.decodeRawAll(token, secret, Seq(algorithm.asInstanceOf[JwtHmacAlgorithm])) match {
case Success(parts) => Some(((parse(parts._2).asInstanceOf[JObject]) \ "userinfo").values.asInstanceOf[UserInfo])
case Failure(err) => None
}
}
}
def issueJwt(userinfo: UserInfo): String = {
val claims = JwtClaim() + Json(DefaultFormats).write(("userinfo", userinfo))
Jwt.encode(claims, secret, algorithm)
}
}
}
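For reference, here is roughly what `Jwt.encode(claims, secret, JwtAlgorithm.HS256)` and `Jwt.isValid` do, as a JDK-only sketch of HS256 in the JWS compact format (header.payload.signature, each part Base64url-encoded without padding). It is an illustration, not jwt-scala's implementation: it takes the claims as a ready-made JSON string, does not check the header's alg field or any exp/nbf claims, and the object name Hs256Sketch is hypothetical.

```scala
import java.nio.charset.StandardCharsets.UTF_8
import java.security.MessageDigest
import java.util.Base64
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import scala.util.Try

object Hs256Sketch {
  private val enc = Base64.getUrlEncoder.withoutPadding
  private val dec = Base64.getUrlDecoder

  private def b64(bytes: Array[Byte]): String = enc.encodeToString(bytes)

  private def hmac(secret: String, data: String): Array[Byte] = {
    val mac = Mac.getInstance("HmacSHA256")
    mac.init(new SecretKeySpec(secret.getBytes(UTF_8), "HmacSHA256"))
    mac.doFinal(data.getBytes(UTF_8))
  }

  // Sign "header.payload" and append the signature as the third part.
  def encode(claimsJson: String, secret: String): String = {
    val header       = b64("""{"alg":"HS256","typ":"JWT"}""".getBytes(UTF_8))
    val payload      = b64(claimsJson.getBytes(UTF_8))
    val signingInput = s"$header.$payload"
    s"$signingInput.${b64(hmac(secret, signingInput))}"
  }

  // Return the decoded claims JSON only if the signature matches.
  def verify(token: String, secret: String): Option[String] =
    token.split('.') match {
      case Array(header, payload, sig) =>
        val expected = hmac(secret, s"$header.$payload")
        val actual   = Try(dec.decode(sig)).getOrElse(Array.emptyByteArray)
        // Constant-time comparison of the two MACs.
        if (MessageDigest.isEqual(expected, actual)) Try(new String(dec.decode(payload), UTF_8)).toOption
        else None
      case _ => None
    }
}
```

Note that the payload can be decoded without the secret: the signature prevents tampering, but it does not hide the userinfo claim.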
The server-side Interceptor is built and installed as follows:
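abstract class FutureListener[Q](implicit ec: ExecutionContext) extends Listener[Q] {
protected val delegate: Future[Listener[Q]]
// a def rather than a val, so the abstract delegate is only read after the subclass has initialized it
private def eventually(f: Listener[Q] => Unit): Unit = delegate.foreach(f)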
override def onComplete(): Unit = eventually { _.onComplete() }
override def onCancel(): Unit = eventually { _.onCancel() }
override def onMessage(message: Q): Unit = eventually { _ onMessage message }
override def onHalfClose(): Unit = eventually { _.onHalfClose() }
override def onReady(): Unit = eventually { _.onReady() }
}
object Keys {
val AUTH_META_KEY: Metadata.Key[String] = of("jwt", Metadata.ASCII_STRING_MARSHALLER)
val AUTH_CTX_KEY: Context.Key[String] = key("jwt")
}
class AuthorizationInterceptor(implicit ec: ExecutionContext) extends ServerInterceptor {
override def interceptCall[Q, R](
call: ServerCall[Q, R],
headers: Metadata,
next: ServerCallHandler[Q, R]
): Listener[Q] = {
val prevCtx = Context.current
val jwt = headers.get(Keys.AUTH_META_KEY)
println(s"!!!!!!!!!!! $jwt !!!!!!!!!!")
new FutureListener[Q] {
protected val delegate = Future {
val nextCtx = prevCtx withValue (Keys.AUTH_CTX_KEY, jwt)
Contexts.interceptCall(nextCtx, call, headers, next)
}
}
}
}
trait gRPCServer {
def runServer(service: ServerServiceDefinition)(implicit actorSys: ActorSystem): Unit = {
import actorSys.dispatcher
val server = NettyServerBuilder
.forPort(50051)
.addService(ServerInterceptors.intercept(service,
new AuthorizationInterceptor))
.build
.start
// make sure our server is stopped when jvm is shut down
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = {
server.shutdown()
server.awaitTermination()
}
})
}
}
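FutureListener lets interceptCall return at once while the real listener is still being built inside a Future; each callback is forwarded when that Future completes. One caveat: callbacks registered with separate foreach calls on the same Future have no guaranteed execution order, so on a multi-threaded dispatcher onMessage and onHalfClose could in principle be delivered out of order. Below is a stdlib-only sketch of one alternative, assuming a simplified Listener trait with a single onEvent method (a hypothetical stand-in, not io.grpc.ServerCall.Listener): each event is chained onto the previous one so delivery order is preserved.

```scala
import scala.concurrent.{ExecutionContext, Future}

object DeferredListenerSketch {
  // Simplified stand-in for io.grpc.ServerCall.Listener.
  trait Listener { def onEvent(e: String): Unit }

  final class OrderedFutureListener(delegate: Future[Listener])(implicit ec: ExecutionContext)
      extends Listener {
    // Tail of the delivery chain; starts as the Future of the real listener.
    private var tail: Future[Listener] = delegate

    // Chain each event after the previous one so the real listener sees
    // events in arrival order, even on a multi-threaded ExecutionContext.
    def onEvent(e: String): Unit = synchronized {
      tail = tail.map { l => l.onEvent(e); l }
    }

    // Completes once every event received so far has been delivered.
    def drained: Future[Listener] = synchronized(tail)
  }
}
```

The trade-off is that a callback that throws fails the chain and later events are dropped, so a production version would need to decide how to handle that case.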
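Note: the request headers sent by the client are visible only to the interceptor installed when the server is built; the service functions themselves cannot access them. gRPC's Context, however, is accessible in both places, so we can move the JWT from the headers into the Context: the server-side Interceptor reads the JWT from the metadata and writes it into the Context. Be very careful, though: a Context is bound to the current thread, so it must be read on the thread where it is attached (Contexts.interceptCall takes care of this by attaching the Context around each listener callback). In each service function that requires authorization, we then read the JWT back from the Context and verify it: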
override def singleResponse(request: PBPOSCommand): Future[PBPOSResponse] = {
val jwt = AUTH_CTX_KEY.get
println(s"***********$jwt**************")
val optUserInfo = authenticator.getUserInfo(jwt)
val shopid = optUserInfo match {
case Some(m) => m("shopid")
case None => "invalid token!"
}
FastFuture.successful(PBPOSResponse(msg=s"shopid:$shopid"))
}
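The thread-binding caveat can be reproduced with a plain ThreadLocal, which is close to how io.grpc.Context is made current per thread. In this stdlib-only sketch, attachAround is a hypothetical helper that sets a value for the duration of a block and restores the previous one afterwards, loosely what Contexts.interceptCall does around each listener callback; a thread on which the value was never attached sees nothing.

```scala
object ThreadContextSketch {
  // Thread-bound slot for the token, standing in for a Context.Key[String].
  val jwtKey = new ThreadLocal[String]

  // Make `jwt` current while `body` runs, then restore the previous value.
  def attachAround[A](jwt: String)(body: => A): A = {
    val previous = jwtKey.get
    jwtKey.set(jwt)
    try body finally jwtKey.set(previous)
  }

  // Read the slot on a freshly started thread and report what it saw.
  def readOnOtherThread(): String = {
    var seen: String = "unset"
    val t = new Thread(() => seen = jwtKey.get)
    t.start()
    t.join() // join() makes the write to `seen` visible here
    seen
  }
}
```

This is why singleResponse reads AUTH_CTX_KEY synchronously at the start of the call instead of inside a later Future callback, which could run on a thread without the Context attached.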
Issuing the JWT is itself a service function:
val authenticator = new AuthBase()
.withAlgorithm(JwtAlgorithm.HS256)
.withSecretKey("OpenSesame")
.withUserFunc(getValidUser)
override def getAuthToken(request: PBPOSCredential): Future[PBPOSToken] = {
getValidUser(request.userid, request.password) match {
case Some(userinfo) => FastFuture.successful(PBPOSToken(authenticator.issueJwt(userinfo)))
case None => FastFuture.successful(PBPOSToken("Invalid Token!"))
}
}
We also need a mock user-authentication function:
package com.datatech.auth
object MockUserAuthService {
type UserInfo = Map[String,Any]
case class User(username: String, password: String, userInfo: UserInfo)
val validUsers = Seq(User("johnny", "p4ssw0rd",Map("shopid" -> "1101", "userid" -> "101"))
,User("tiger", "secret", Map("shopid" -> "1101" , "userid" -> "102")))
def getValidUser(userid: String, pswd: String): Option[UserInfo] =
validUsers.find(user => user.username == userid && user.password == pswd) match {
case Some(user) => Some(user.userInfo)
case _ => None
}
}
Below is the complete source code for this demo:
project/plugins.sbt
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.9")
addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.2")
addSbtPlugin("com.typesafe.sbt" % "sbt-native-packager" % "1.3.15")
addSbtPlugin("com.thesamet" % "sbt-protoc" % "0.99.21")
libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.9.0-M6"
build.sbt
name := "grpc-jwt"
version := "0.1"
scalaVersion := "2.12.8"
scalacOptions += "-Ypartial-unification"
val akkaversion = "2.5.23"
libraryDependencies := Seq(
"com.typesafe.akka" %% "akka-cluster-metrics" % akkaversion,
"com.typesafe.akka" %% "akka-cluster-sharding" % akkaversion,
"com.typesafe.akka" %% "akka-persistence" % akkaversion,
"com.lightbend.akka" %% "akka-stream-alpakka-cassandra" % "1.0.1",
"org.mongodb.scala" %% "mongo-scala-driver" % "2.6.0",
"com.lightbend.akka" %% "akka-stream-alpakka-mongodb" % "1.0.1",
"com.typesafe.akka" %% "akka-persistence-query" % akkaversion,
"com.typesafe.akka" %% "akka-persistence-cassandra" % "0.97",
"com.datastax.cassandra" % "cassandra-driver-core" % "3.6.0",
"com.datastax.cassandra" % "cassandra-driver-extras" % "3.6.0",
"ch.qos.logback" % "logback-classic" % "1.2.3",
"io.monix" %% "monix" % "3.0.0-RC2",
"org.typelevel" %% "cats-core" % "2.0.0-M1",
"io.grpc" % "grpc-netty" % scalapb.compiler.Version.grpcJavaVersion,
"io.netty" % "netty-tcnative-boringssl-static" % "2.0.22.Final",
"com.thesamet.scalapb" %% "scalapb-runtime" % scalapb.compiler.Version.scalapbVersion % "protobuf",
"com.thesamet.scalapb" %% "scalapb-runtime-grpc" % scalapb.compiler.Version.scalapbVersion,
"com.pauldijou" %% "jwt-core" % "3.0.1",
"de.heikoseeberger" %% "akka-http-json4s" % "1.22.0",
"org.json4s" %% "json4s-native" % "3.6.1",
"com.typesafe.akka" %% "akka-http-spray-json" % "10.1.8",
"org.json4s" %% "json4s-jackson" % "3.6.7",
"org.json4s" %% "json4s-ext" % "3.6.7"
)
// (optional) If you need scalapb/scalapb.proto or anything from
// google/protobuf/*.proto
//libraryDependencies += "com.thesamet.scalapb" %% "scalapb-runtime" % scalapb.compiler.Version.scalapbVersion % "protobuf"
PB.targets in Compile := Seq(
scalapb.gen() -> (sourceManaged in Compile).value
)
enablePlugins(JavaAppPackaging)
main/protobuf/posmessages.proto
syntax = "proto3";
import "google/protobuf/wrappers.proto";
import "google/protobuf/any.proto";
import "scalapb/scalapb.proto";
option (scalapb.options) = {
// use a custom Scala package name
// package_name: "io.ontherocks.introgrpc.demo"
// don't append file name to package
flat_package: true
// generate one Scala file for all messages (services still get their own file)
single_file: true
// add imports to generated file
// useful when extending traits or using custom types
// import: "io.ontherocks.hellogrpc.RockingMessage"
// code to put at the top of generated file
// works only with `single_file: true`
//preamble: "sealed trait SomeSealedTrait"
};
package com.datatech.pos.messages;
message PBVchState { //voucher state
string opr = 1; //cashier
int64 jseq = 2; //begin journal sequence for read-side replay
int32 num = 3; //current voucher number
int32 seq = 4; //current sequence number
bool void = 5; //void mode
bool refd = 6; //refund mode
bool susp = 7; //voucher suspended
bool canc = 8; //voucher cancelled
bool due = 9; //current balance due
string su = 10; //supervisor id
string mbr = 11; //member number
int32 mode = 12; //current operation stage: 0=logOff, 1=LogOn, 2=Payment
}
message PBTxnItem { //transaction record
string txndate = 1; //transaction date
string txntime = 2; //entry time
string opr = 3; //operator
int32 num = 4; //sales voucher number
int32 seq = 5; //transaction sequence number
int32 txntype = 6; //transaction type
int32 salestype = 7; //sales type
int32 qty = 8; //quantity
int32 price = 9; //unit price (in cents)
int32 amount = 10; //gross amount at list price (in cents)
int32 disc = 11; //discount rate (%)
int32 dscamt = 12; //discount amount, negative: net amount = amount + dscamt
string member = 13; //membership card number
string code = 14; //code (item, card number, ...)
string acct = 15; //account number
string dpt = 16; //department
}
message PBPOSResponse {
int32 sts = 1;
string msg = 2;
PBVchState voucher = 3;
repeated PBTxnItem txnitems = 4;
}
message PBPOSCommand {
string commandname = 1;
string delimitedparams = 2;
}
message PBPOSCredential {
string userid = 1;
string password = 2;
}
message PBPOSToken {
string jwt = 1;
}
service SendCommand {
rpc SingleResponse(PBPOSCommand) returns (PBPOSResponse) {};
rpc GetTxnItems(PBPOSCommand) returns (stream PBTxnItem) {};
rpc GetAuthToken(PBPOSCredential) returns (PBPOSToken) {};
}
gRPCServer.scala
package com.datatech.grpc.server
import io.grpc.ServerServiceDefinition
import io.grpc.netty.NettyServerBuilder
import io.grpc.ServerInterceptors
import scala.concurrent._
import io.grpc.Context
import io.grpc.Contexts
import io.grpc.ServerCall
import io.grpc.ServerCallHandler
import io.grpc.ServerInterceptor
import io.grpc.Metadata
import io.grpc.Metadata.Key.of
import io.grpc.Context.key
import io.grpc.ServerCall.Listener
import akka.actor._
abstract class FutureListener[Q](implicit ec: ExecutionContext) extends Listener[Q] {
protected val delegate: Future[Listener[Q]]
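// a def rather than a val, so the abstract delegate is only read after the subclass has initialized it
private def eventually(f: Listener[Q] => Unit): Unit = delegate.foreach(f)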
override def onComplete(): Unit = eventually { _.onComplete() }
override def onCancel(): Unit = eventually { _.onCancel() }
override def onMessage(message: Q): Unit = eventually { _ onMessage message }
override def onHalfClose(): Unit = eventually { _.onHalfClose() }
override def onReady(): Unit = eventually { _.onReady() }
}
object Keys {
val AUTH_META_KEY: Metadata.Key[String] = of("jwt", Metadata.ASCII_STRING_MARSHALLER)
val AUTH_CTX_KEY: Context.Key[String] = key("jwt")
}
class AuthorizationInterceptor(implicit ec: ExecutionContext) extends ServerInterceptor {
override def interceptCall[Q, R](
call: ServerCall[Q, R],
headers: Metadata,
next: ServerCallHandler[Q, R]
): Listener[Q] = {
val prevCtx = Context.current
val jwt = headers.get(Keys.AUTH_META_KEY)
println(s"!!!!!!!!!!! $jwt !!!!!!!!!!")
new FutureListener[Q] {
protected val delegate = Future {
val nextCtx = prevCtx withValue (Keys.AUTH_CTX_KEY, jwt)
Contexts.interceptCall(nextCtx, call, headers, next)
}
}
}
}
trait gRPCServer {
def runServer(service: ServerServiceDefinition)(implicit actorSys: ActorSystem): Unit = {
import actorSys.dispatcher
val server = NettyServerBuilder
.forPort(50051)
.addService(ServerInterceptors.intercept(service,
new AuthorizationInterceptor))
.build
.start
// make sure our server is stopped when jvm is shut down
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = {
server.shutdown()
server.awaitTermination()
}
})
}
}
POSServices.scala
package com.datatech.pos.service
import com.datatech.grpc.server.Keys._
import akka.http.scaladsl.util.FastFuture
import com.datatech.pos.messages._
import com.datatech.grpc.server._
import com.datatech.auth.MockUserAuthService._
import scala.concurrent.Future
import com.datatech.auth.AuthBase._
import pdi.jwt._
import akka.actor._
import io.grpc.stub.StreamObserver
object POSServices extends gRPCServer {
type UserInfo = Map[String, Any]
class POSServices extends SendCommandGrpc.SendCommand {
val authenticator = new AuthBase()
.withAlgorithm(JwtAlgorithm.HS256)
.withSecretKey("OpenSesame")
.withUserFunc(getValidUser)
override def getTxnItems(request: PBPOSCommand, responseObserver: StreamObserver[PBTxnItem]): Unit = ???
override def singleResponse(request: PBPOSCommand): Future[PBPOSResponse] = {
val jwt = AUTH_CTX_KEY.get
println(s"***********$jwt**************")
val optUserInfo = authenticator.getUserInfo(jwt)
val shopid = optUserInfo match {
case Some(m) => m("shopid")
case None => "invalid token!"
}
FastFuture.successful(PBPOSResponse(msg=s"shopid:$shopid"))
}
override def getAuthToken(request: PBPOSCredential): Future[PBPOSToken] = {
getValidUser(request.userid, request.password) match {
case Some(userinfo) => FastFuture.successful(PBPOSToken(authenticator.issueJwt(userinfo)))
case None => FastFuture.successful(PBPOSToken("Invalid Token!"))
}
}
}
def main(args: Array[String]) = {
implicit val system = ActorSystem("grpc-system")
val svc = SendCommandGrpc.bindService(new POSServices, system.dispatcher)
runServer(svc)
}
}
AuthBase.scala
package com.datatech.auth
import pdi.jwt._
import org.json4s.native.Json
import org.json4s._
import org.json4s.jackson.JsonMethods._
import pdi.jwt.algorithms._
import scala.util._
object AuthBase {
type UserInfo = Map[String, Any]
case class AuthBase(
algorithm: JwtAlgorithm = JwtAlgorithm.HMD5,
secret: String = "OpenSesame",
getUserInfo: (String,String) => Option[UserInfo] = null) {
ctx =>
def withAlgorithm(algo: JwtAlgorithm): AuthBase = ctx.copy(algorithm = algo)
def withSecretKey(key: String): AuthBase = ctx.copy(secret = key)
def withUserFunc(f: (String, String) => Option[UserInfo]): AuthBase = ctx.copy(getUserInfo = f)
def authenticateToken(token: String): Option[String] =
algorithm match {
case algo: JwtAsymmetricAlgorithm =>
Jwt.isValid(token, secret, Seq((algorithm.asInstanceOf[JwtAsymmetricAlgorithm]))) match {
case true => Some(token)
case _ => None
}
case _ =>
Jwt.isValid(token, secret, Seq((algorithm.asInstanceOf[JwtHmacAlgorithm]))) match {
case true => Some(token)
case _ => None
}
}
def getUserInfo(token: String): Option[UserInfo] = {
algorithm match {
case algo: JwtAsymmetricAlgorithm =>
Jwt.decodeRawAll(token, secret, Seq(algorithm.asInstanceOf[JwtAsymmetricAlgorithm])) match {
case Success(parts) => Some(((parse(parts._2).asInstanceOf[JObject]) \ "userinfo").values.asInstanceOf[UserInfo])
case Failure(err) => None
}
case _ =>
Jwt.decodeRawAll(token, secret, Seq(algorithm.asInstanceOf[JwtHmacAlgorithm])) match {
case Success(parts) => Some(((parse(parts._2).asInstanceOf[JObject]) \ "userinfo").values.asInstanceOf[UserInfo])
case Failure(err) => None
}
}
}
def issueJwt(userinfo: UserInfo): String = {
val claims = JwtClaim() + Json(DefaultFormats).write(("userinfo", userinfo))
Jwt.encode(claims, secret, algorithm)
}
}
}
POSClient.scala
package com.datatech.pos.client
import com.datatech.pos.messages.{PBPOSCommand, PBPOSCredential, SendCommandGrpc}
import io.grpc.stub.StreamObserver
import io.grpc.netty.{ NegotiationType, NettyChannelBuilder}
import io.grpc.CallOptions
import io.grpc.ClientCall
import io.grpc.ClientInterceptor
import io.grpc.ForwardingClientCall
import io.grpc.Metadata
import io.grpc.Metadata.Key
import io.grpc.MethodDescriptor
import io.grpc.ClientInterceptors
object POSClient {
class AuthClientInterceptor(jwt: String) extends ClientInterceptor {
def interceptCall[ReqT, RespT](methodDescriptor: MethodDescriptor[ReqT, RespT], callOptions: CallOptions, channel: io.grpc.Channel): ClientCall[ReqT, RespT] =
new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](channel.newCall(methodDescriptor, callOptions)) {
override def start(responseListener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
headers.put(Key.of("jwt", Metadata.ASCII_STRING_MARSHALLER), jwt)
super.start(responseListener, headers)
}
}
}
def main(args: Array[String]): Unit = {
//build connection channel
val unsafeChannel = NettyChannelBuilder
.forAddress("192.168.0.189",50051)
.negotiationType(NegotiationType.PLAINTEXT)
.build()
val authClient = SendCommandGrpc.blockingStub(unsafeChannel)
val jwt = authClient.getAuthToken(PBPOSCredential(userid="johnny",password="p4ssw0rd")).jwt
println(s"got jwt: $jwt")
val securedChannel = ClientInterceptors.intercept(unsafeChannel, new AuthClientInterceptor(jwt))
val securedClient = SendCommandGrpc.blockingStub(securedChannel)
val resp = securedClient.singleResponse(PBPOSCommand())
println(s"secured response: $resp")
// wait for async execution
scala.io.StdIn.readLine()
}
}