/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace folly { class AsyncDetachFdCallback { public: virtual ~AsyncDetachFdCallback() = default; virtual void fdDetached( NetworkSocket ns, std::unique_ptr unread) noexcept = 0; virtual void fdDetachFail(const AsyncSocketException& ex) noexcept = 0; }; } // namespace folly #if FOLLY_HAS_LIBURING class IoUringBackend; namespace folly { class AsyncIoUringSocket : public AsyncSocketTransport { public: using Cert = folly::AsyncTransportCertificate; struct Options { Options() : allocateNoBufferPoolBuffer(defaultAllocateNoBufferPoolBuffer), multishotRecv(true) {} static std::unique_ptr defaultAllocateNoBufferPoolBuffer(); folly::Function()> allocateNoBufferPoolBuffer; folly::Optional zeroCopyEnable; bool multishotRecv; }; using UniquePtr = std::unique_ptr; explicit AsyncIoUringSocket( AsyncTransport::UniquePtr other, Options&& options = Options{}); explicit AsyncIoUringSocket(AsyncSocket* sock, Options&& options = Options{}); explicit AsyncIoUringSocket(EventBase* evb, Options&& options = Options{}); explicit AsyncIoUringSocket( EventBase* evb, NetworkSocket ns, Options&& options = Options{}); static bool supports(EventBase* backend); static bool supportsZcRx(EventBase* backend); void connect( AsyncSocket::ConnectCallback* callback, const folly::SocketAddress& address, std::chrono::milliseconds timeout = std::chrono::milliseconds(0), SocketOptionMap const& options = emptySocketOptionMap, const SocketAddress& bindAddr = anyAddress(), const std::string& ifName = std::string()) noexcept; void connect( ConnectCallback* callback, const folly::SocketAddress& address, int timeout, SocketOptionMap const& options, const SocketAddress& bindAddr, const std::string& ifName) noexcept override { connect( callback, address, std::chrono::milliseconds(timeout), options, bindAddr, ifName); } std::chrono::nanoseconds getConnectTime() const { return connectEndTime_ - connectStartTime_; } // AsyncSocketBase EventBase* getEventBase() const override { return evb_; } // AsyncReader void setReadCB(ReadCallback* callback) override; ReadCallback* getReadCallback() const override { return readSqe_->readCallback(); } std::unique_ptr takePreReceivedData() override { return readSqe_->takePreReceivedData(); } // AsyncWriter void write(WriteCallback*, const void*, size_t, WriteFlags = WriteFlags::NONE) override; void writev( WriteCallback*, const iovec*, size_t, WriteFlags = WriteFlags::NONE) override; void writeChain( WriteCallback* callback, std::unique_ptr&& buf, WriteFlags flags) override; bool canZC(std::unique_ptr const& buf) const; // AsyncTransport void close() override; void closeNow() override; void closeWithReset() override; void shutdownWrite() override; void shutdownWriteNow() override; bool good() const override; bool readable() const override { return good(); } bool error() const override; bool hangup() const override; bool connecting() const override { return connectSqe_ && connectSqe_->inFlight(); } void attachEventBase(EventBase*) override; void detachEventBase() override; bool isDetachable() const override; uint32_t getSendTimeout() const override { return static_cast( std::chrono::duration_cast(writeTimeoutTime_) .count()); } void setSendTimeout(uint32_t ms) override; void getLocalAddress(SocketAddress* address) const override; void getPeerAddress(SocketAddress*) const override; void setPreReceivedData(std::unique_ptr data) override; void cacheAddresses() override; /** * @return True iff end of record tracking is enabled */ bool isEorTrackingEnabled() const override { return false; } void setEorTracking(bool) override { // don't support this. // as far as I can see this is only used by AsyncSSLSocket, but TLS1.3 // supercedes this so I think we can ignore it. throw std::runtime_error( "AsyncIoUringSocket::setEorTracking not supported"); } size_t getAppBytesWritten() const override { return getRawBytesWritten(); } size_t getRawBytesWritten() const override { return bytesWritten_; } size_t getAppBytesReceived() const override { return getRawBytesReceived(); } size_t getRawBytesReceived() const override; const AsyncTransport* getWrappedTransport() const override { return nullptr; } // AsyncSocketTransport int setNoDelay(bool noDelay) override; int setSockOpt( int level, int optname, const void* optval, socklen_t optsize) override; std::string getSecurityProtocol() const override { return securityProtocol_; } std::string getApplicationProtocol() const noexcept override { return applicationProtocol_; } NetworkSocket getNetworkSocket() const override { return fd_; } void setSecurityProtocol(std::string s) { securityProtocol_ = std::move(s); } void setApplicationProtocol(std::string s) { applicationProtocol_ = std::move(s); } const folly::AsyncTransportCertificate* getPeerCertificate() const override { return peerCert_.get(); } const folly::AsyncTransportCertificate* getSelfCertificate() const override { return selfCert_.get(); } void dropPeerCertificate() noexcept override { peerCert_.reset(); } void dropSelfCertificate() noexcept override { selfCert_.reset(); } void setPeerCertificate(const std::shared_ptr& peerCert) { peerCert_ = peerCert; } void setSelfCertificate(const std::shared_ptr& selfCert) { selfCert_ = selfCert; } void asyncDetachFd(AsyncDetachFdCallback* callback); bool readSqeInFlight() const { return readSqe_->inFlight(); } bool getTFOSucceded() const override; void enableTFO() override { // No-op if folly does not allow tfo #if FOLLY_ALLOW_TFO VLOG(5) << "AsyncIoUringSocket::enableTFO()"; enableTFO_ = true; #endif } void appendPreReceive(std::unique_ptr iobuf) noexcept; protected: ~AsyncIoUringSocket() override; private: friend class ReadSqe; friend class WriteSqe; void setFd(NetworkSocket ns); void registerFd(); void unregisterFd(); void readProcessSubmit( struct io_uring_sqe* sqe, IoUringBufferProviderBase* bufferProvider, size_t* maxSize, IoUringBufferProviderBase* usedBufferProvider) noexcept; void readCallback( int res, uint32_t flags, size_t maxSize, IoUringBufferProviderBase* bufferProvider) noexcept; void allowReads(); void previousReadDone(); void processWriteQueue() noexcept; void setStateEstablished(); void writeDone() noexcept; void doSubmitWrite() noexcept; void doReSubmitWrite() noexcept; void failAllWrites() noexcept; void submitRead(bool now = false); void processConnectSubmit( struct io_uring_sqe* sqe, sockaddr_storage& storage); void processConnectResult(const io_uring_cqe* cqe); void processConnectTimeout(); void processFastOpenResult(const io_uring_cqe* cqe) noexcept; void startSendTimeout(); void sendTimeoutExpired(); void failWrite(const AsyncSocketException& ex); void readEOF(); void readError(); NetworkSocket takeFd(); bool setZeroCopy(bool enable) override; bool getZeroCopy() const override; void setZeroCopyEnableFunc(AsyncWriter::ZeroCopyEnableFunc func) override; enum class State { None, Connecting, Established, Closed, Error, FastOpen, }; static std::string toString(State s); std::string stateAsString() const { return toString(state_); } struct ReadSqe : IoSqeBase, DelayedDestruction { using UniquePtr = std::unique_ptr; explicit ReadSqe(AsyncIoUringSocket* parent); void processSubmit(struct io_uring_sqe* sqe) noexcept override; void callback(const io_uring_cqe* cqe) noexcept override; void callbackCancelled(const io_uring_cqe* cqe) noexcept override; void setReadCallback(ReadCallback* callback, bool submitNow); ReadCallback* readCallback() const { return readCallback_; } size_t bytesReceived() const { return bytesReceived_; } std::unique_ptr takePreReceivedData(); void appendPreReceive(std::unique_ptr data) noexcept { appendReadData(std::move(data), preReceivedData_); } void destroy() override { parent_ = nullptr; DelayedDestruction::destroy(); } bool waitingForOldEventBaseRead() const; void setOldEventBaseRead(folly::SemiFuture>&& f) { oldEventBaseRead_ = std::move(f); } void attachEventBase(); folly::Optional>> detachEventBase(); private: ~ReadSqe() override = default; void appendReadData( std::unique_ptr data, std::unique_ptr& overflow) noexcept; void sendReadBuf( std::unique_ptr buf, std::unique_ptr& overflow) noexcept; bool readCallbackUseIoBufs() const; void invalidState(ReadCallback* callback); void processOldEventBaseRead(); bool isEOF(const io_uring_cqe* cqe) noexcept; IoUringBufferProviderBase* lastUsedBufferProvider_; ReadCallback* readCallback_ = nullptr; AsyncIoUringSocket* parent_; size_t maxSize_; uint64_t setReadCbCount_{0}; size_t bytesReceived_{0}; std::unique_ptr queuedReceivedData_; std::unique_ptr preReceivedData_; std::unique_ptr tmpBuffer_; bool supportsMultishotRecv_ = false; // todo: this can be per process instead of per socket bool supportsZeroCopyRx_ = false; folly::Optional>> oldEventBaseRead_; std::shared_ptr alive_; }; struct CloseSqe : IoSqeBase { explicit CloseSqe(AsyncIoUringSocket* parent) : IoSqeBase(IoSqeBase::Type::Close), parent_(parent) { setEventBase(parent->evb_); } void processSubmit(struct io_uring_sqe* sqe) noexcept override { parent_->closeProcessSubmit(sqe); } void callback(const io_uring_cqe*) noexcept override { delete this; } void callbackCancelled(const io_uring_cqe*) noexcept override { delete this; } AsyncIoUringSocket* parent_; }; struct write_sqe_tag; using write_sqe_hook = boost::intrusive::list_base_hook>; struct WriteSqe final : IoSqeBase, public write_sqe_hook { explicit WriteSqe( AsyncIoUringSocket* parent, WriteCallback* callback, std::unique_ptr&& buf, WriteFlags flags, bool zc); ~WriteSqe() override { VLOG(5) << "~WriteSqe() " << this; } void processSubmit(struct io_uring_sqe* sqe) noexcept override; void callback(const io_uring_cqe* cqe) noexcept override; void callbackCancelled(const io_uring_cqe* cqe) noexcept override; int sendMsgFlags() const; std::pair< folly::SemiFuture>>, WriteSqe*> detachEventBase(); boost::intrusive::list_member_hook<> member_hook_; AsyncIoUringSocket* parent_; WriteCallback* callback_; std::unique_ptr buf_; WriteFlags flags_; static constexpr size_t kSmallIoVecSize = 16; small_vector iov_; size_t totalLength_; struct msghdr msg_; bool zerocopy_{false}; int refs_ = 1; folly::Function detachedSignal_; }; using WriteSqeList = boost::intrusive::list< WriteSqe, boost::intrusive::base_hook, boost::intrusive::constant_time_size>; class WriteTimeout : public AsyncTimeout { public: explicit WriteTimeout(AsyncIoUringSocket* socket) : AsyncTimeout(socket->evb_), socket_(socket) {} void timeoutExpired() noexcept override { socket_->sendTimeoutExpired(); } private: AsyncIoUringSocket* socket_; }; struct ConnectSqe : IoSqeBase, AsyncTimeout { explicit ConnectSqe(AsyncIoUringSocket* parent) : IoSqeBase(IoSqeBase::Type::Connect), AsyncTimeout(parent->evb_), parent_(parent) { setEventBase(parent->evb_); } void processSubmit(struct io_uring_sqe* sqe) noexcept override { parent_->processConnectSubmit(sqe, addrStorage); } void callback(const io_uring_cqe* cqe) noexcept override { parent_->processConnectResult(cqe); } void callbackCancelled(const io_uring_cqe*) noexcept override { delete this; } void timeoutExpired() noexcept override { if (!cancelled()) { parent_->processConnectTimeout(); } } AsyncIoUringSocket* parent_; sockaddr_storage addrStorage; }; struct FastOpenSqe : IoSqeBase { explicit FastOpenSqe( AsyncIoUringSocket* parent, SocketAddress const& addr, std::unique_ptr initialWrite); void processSubmit(struct io_uring_sqe* sqe) noexcept override; void cleanupMsg() noexcept; void callback(const io_uring_cqe* cqe) noexcept override { cleanupMsg(); parent_->processFastOpenResult(cqe); } void callbackCancelled(const io_uring_cqe*) noexcept override { delete this; } AsyncIoUringSocket* parent_; std::unique_ptr initialWrite; size_t addrLen_; sockaddr_storage addrStorage; }; EventBase* evb_ = nullptr; NetworkSocket fd_; IoUringBackend* backend_ = nullptr; Options options_; mutable SocketAddress localAddress_; mutable SocketAddress peerAddress_; IoUringFdRegistrationRecord* fdRegistered_ = nullptr; int usedFd_ = -1; unsigned int mbFixedFileFlags_ = 0; std::unique_ptr closeSqe_{new CloseSqe(this)}; State state_ = State::None; // read friend struct DetachFdState; ReadSqe::UniquePtr readSqe_; // write std::chrono::milliseconds writeTimeoutTime_{0}; WriteTimeout writeTimeout_{this}; WriteSqe* writeSqeActive_ = nullptr; WriteSqeList writeSqeQueue_; size_t bytesWritten_{0}; // connect std::unique_ptr connectSqe_; AsyncSocket::ConnectCallback* connectCallback_; std::chrono::milliseconds connectTimeout_{0}; std::chrono::steady_clock::time_point connectStartTime_; std::chrono::steady_clock::time_point connectEndTime_; // stopTLS helpers: std::string securityProtocol_; std::string applicationProtocol_; std::shared_ptr selfCert_; std::shared_ptr peerCert_; // shutdown: int shutdownFlags_ = 0; // TCP fast open std::unique_ptr fastOpenSqe_; bool enableTFO_ = false; // detach event base bool isDetaching_ = false; Optional>>> detachedWriteResult_; std::shared_ptr alive_; void closeProcessSubmit(struct io_uring_sqe* sqe); }; } // namespace folly #endif